Merge commit for internal changes

Erich Elsen 2017-12-20 11:01:08 -08:00
commit 15283d90e7
102 changed files with 4784 additions and 595 deletions

View File

@@ -42,6 +42,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
],
}),
)
@@ -73,6 +74,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -2618,4 +2619,54 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
output_values, target_names, nullptr, status);
}
TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
tensorflow::OpList op_list;
if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
status->status = InvalidArgument("Unparseable OpList");
return nullptr;
}
status->status = Status::OK();
return new TF_ApiDefMap(op_list);
}
void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
size_t text_len, TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"ApiDefMap is not supported in Android.");
#else
mutex_lock l(api_def_map->lock);
if (api_def_map->update_docs_called) {
status->status = FailedPrecondition(
"TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
"called.");
return;
}
string api_def_text(text, text_len);
status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif // __ANDROID__
}
TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
size_t name_len, TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"ApiDefMap is not supported in Android.");
return nullptr;
#else
mutex_lock l(api_def_map->lock);
if (!api_def_map->update_docs_called) {
api_def_map->api_def_map.UpdateDocs();
api_def_map->update_docs_called = true;
}
string name_str(name, name_len);
const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(*api_def, ret);
return ret;
#endif // __ANDROID__
}
} // end extern "C"

View File

@@ -1518,6 +1518,49 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
// in this address space.
TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList();
// TF_ApiDefMap encapsulates a collection of API definitions for an operation.
//
// This object maps the name of a TensorFlow operation to a description of the
// API to generate for it, as defined by the ApiDef protocol buffer (
// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto)
//
// The ApiDef messages are typically used to generate convenience wrapper
// functions for TensorFlow operations in various language bindings.
typedef struct TF_ApiDefMap TF_ApiDefMap;
// Creates a new TF_ApiDefMap instance.
//
// Params:
// op_list_buffer - TF_Buffer instance containing serialized OpList
// protocol buffer. (See
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto
// for the OpList proto definition).
// status - Set to OK on success and an appropriate error on failure.
TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer,
TF_Status* status);
// Deallocates a TF_ApiDefMap.
TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap);
// Add ApiDefs to the map.
//
// `text` corresponds to a text representation of an ApiDefs protocol message.
// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto).
//
// The provided ApiDefs will be merged with existing ones in the map, with
// precedence given to the newly added version in case of conflicts with
// previous calls to TF_ApiDefMapPut.
TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map,
const char* text, size_t text_len,
TF_Status* status);
// Returns a serialized ApiDef protocol buffer for the TensorFlow operation
// named `name`.
TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map,
const char* name,
size_t name_len,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
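For context, here is a minimal usage sketch of the ApiDef C API declared above. This is not part of the commit: "MyOp" and the overlay text are illustrative, and error handling is abbreviated.

// Sketch: build an ApiDefMap for all registered ops, overlay one ApiDef, and
// fetch the merged result as a serialized ApiDef proto.
#include <cstring>
#include "tensorflow/c/c_api.h"

TF_Buffer* GetMergedApiDef(const char* op_name, TF_Status* status) {
  TF_Buffer* op_list = TF_GetAllOpList();
  TF_ApiDefMap* api_def_map = TF_NewApiDefMap(op_list, status);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteBuffer(op_list);
    return nullptr;
  }
  // Overlays must come before the first Get: TF_ApiDefMapGet freezes the map,
  // and a later TF_ApiDefMapPut fails with FailedPrecondition.
  const char* overlay =
      "op: < graph_op_name: \"MyOp\" summary: \"My summary\" >";
  TF_ApiDefMapPut(api_def_map, overlay, strlen(overlay), status);
  // On success, `api_def` holds a serialized ApiDef proto for `op_name`.
  TF_Buffer* api_def =
      TF_ApiDefMapGet(api_def_map, op_name, strlen(op_name), status);
  TF_DeleteApiDefMap(api_def_map);
  TF_DeleteBuffer(op_list);
  return api_def;
}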

View File

@@ -24,6 +24,9 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#ifndef __ANDROID__
#include "tensorflow/core/framework/op_gen_lib.h"
#endif
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -158,6 +161,22 @@ struct TF_Function {
tensorflow::FunctionDef fdef;
};
struct TF_ApiDefMap {
explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
:
#ifndef __ANDROID__
api_def_map(op_list),
#endif
update_docs_called(false) {
}
#ifndef __ANDROID__
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
#endif
bool update_docs_called GUARDED_BY(lock);
tensorflow::mutex lock;
};
namespace tensorflow {
class TensorCApi {

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
@@ -2027,6 +2028,77 @@ TEST_F(CApiAttributesTest, Errors) {
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
}
TEST(TestApiDef, TestCreateApiDef) {
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TF_Buffer op_list_buf = TF_GetOpList(lib);
status = TF_NewStatus();
auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
string op_name = "TestCApi";
status = TF_NewStatus();
auto* api_def_buf =
TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
tensorflow::ApiDef api_def;
EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
EXPECT_EQ(op_name, api_def.graph_op_name());
EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
TF_DeleteLibraryHandle(lib);
}
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TF_Buffer op_list_buf = TF_GetOpList(lib);
status = TF_NewStatus();
auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
string api_def_overwrites = R"(op: <
graph_op_name: "TestCApi"
summary: "New summary"
>
)";
status = TF_NewStatus();
TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(),
api_def_overwrites.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
string op_name = "TestCApi";
status = TF_NewStatus();
auto* api_def_buf =
TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
tensorflow::ApiDef api_def;
EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
EXPECT_EQ(op_name, api_def.graph_op_name());
EXPECT_EQ("New summary", api_def.summary());
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
TF_DeleteLibraryHandle(lib);
}
#undef EXPECT_TF_META
} // namespace

View File

@@ -33,7 +33,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
-}),
+}) + ["//tensorflow/core:gpu_runtime"],
)
tf_cuda_library(
@@ -55,6 +55,10 @@ tf_cuda_library(
tf_cuda_cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
tags = [
"guitar",
"multi_gpu",
],
deps = [
":c_api",
"//tensorflow/core:lib",

View File

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -167,18 +168,6 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
if (is_same_device) {
return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
}
- const bool src_cpu = IsCPU(srcd);
- if (src_cpu == dst_cpu) {
-   TF_SetStatus(
-       status, TF_INVALID_ARGUMENT,
-       tensorflow::strings::StrCat(
-           "TFE_TensorHandleCopyToDevice requires either the source "
-           "TFE_TensorHandle be on or the destination device be on CPU "
-           "or be the same (they are ",
-           DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
-           .c_str());
-   return nullptr;
- }
tensorflow::Tensor* src = &(h->t);
if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) {
TF_SetStatus(
@@ -189,26 +178,19 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
.c_str());
return nullptr;
}
- if (src_cpu) {
-   tensorflow::Tensor dst(
-       dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
-       src->shape());
-   if (src->shape().num_elements() == 0) {
-     return new TFE_TensorHandle(dst, dstd);
-   }
-   tensorflow::Notification n;
-   dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice(
-       src, dstd, &dst, [status, &n](const tensorflow::Status& s) {
-         status->status = s;
-         n.Notify();
-       });
-   n.WaitForNotification();
-   return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd)
-                                        : nullptr;
- }
- CHECK(dst_cpu);
- tensorflow::Tensor dst(src->dtype(), src->shape());
- tensorflow::Notification n;
+ tensorflow::Tensor dst(dstd->GetAllocator(tensorflow::AllocatorAttributes()),
+                        src->dtype(), src->shape());
+ if (src->shape().num_elements() == 0) {
+   return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
+ }
+ tensorflow::DeviceContext* src_device_context = nullptr;
+ if (!IsCPU(srcd)) {
+   src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
+ }
+ tensorflow::DeviceContext* dst_device_context = nullptr;
+ if (!dst_cpu) {
+   dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
+ }
// TODO(ashankar): The Sync() call below may be more aggressive than
// necessary. It is based on knowledge of implementation details - that
// GPU devices are implemented using 3 streams - one for host->device copies,
@@ -217,15 +199,17 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
// but more than necessary (since it waits for operations that might have
// nothing to do with this tensor to complete).
status->status = srcd->Sync();
- if (!status->status.ok()) return nullptr;
- srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU(
-     src, "IGNORE_MY_TENSOR_NAME", srcd, &dst,
+ tensorflow::Notification n;
+ tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
+                                srcd, dstd, tensorflow::AllocatorAttributes(),
+                                tensorflow::AllocatorAttributes(), src, &dst,
[status, &n](const tensorflow::Status& s) {
status->status = s;
n.Notify();
});
n.WaitForNotification();
- return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr)
-                                      : nullptr;
+ return (TF_GetCode(status) == TF_OK)
+            ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
+            : nullptr;
}

View File

@@ -216,6 +216,64 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
const char* kCPUDevice = "CPU:0";
if (num_devices < 3) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
return;
}
const string gpu_1_name(TF_DeviceListName(devices, 1, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
const string gpu_2_name(TF_DeviceListName(devices, 2, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice(
hdevice, ctx, gpu_2_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_DeleteTensorHandle(hdevice);
// Copy back to CPU
TFE_TensorHandle* hcopy =
TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_DeleteTensorHandle(hdevice2);
// Ensure that the contents are the same!
TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
TFE_DeleteTensorHandle(hcopy);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy));
EXPECT_EQ(
0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)));
TF_DeleteTensor(tcopy);
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, TensorHandleSilentCopy) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

View File

@@ -248,6 +248,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",

File diff suppressed because it is too large.

View File

@@ -48,6 +48,16 @@ typedef std::function<Status(
// 'group_attribute' must be a string-valued attribute that names the new
// functions to introduce.
//
// 'outside_compilation_attribute' must be a string-valued attribute that is
// used to tag nodes within a subgraph to be part of an 'outside_compilation'
// cluster within the subgraph. A cluster is formed from the set of nodes with
// the same value of outside_compilation_subgraph and group_attribute. The nodes
// in an outside_compilation cluster are left in the original graph. Edges
// crossing from the subgraph to an outside_compilation cluster nested in the
// subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
// crossing from an outside_compilation cluster into its enclosing subgraph are
// lifted into a SendFromHost/RecvFromHost pair of nodes.
//
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
//
@@ -64,10 +74,10 @@ typedef std::function<Status(
// dep from B. Originally D must run after C, post-transformation this
// dependency is lost.
Status EncapsulateSubgraphsInFunctions(
-    string group_attribute, const Graph& graph_in,
-    const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
-    bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
-    FunctionLibraryDefinition* library);
+    string group_attribute, string outside_compilation_attribute,
+    const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
+    bool parallel_checking, bool reuse_existing_functions,
+    std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
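For reference, a call site under the new signature looks roughly like this. It is a sketch mirroring the Encapsulate() helper in the updated tests below; `graph_in` and the required headers are assumed to exist.

// Sketch: the outside_compilation attribute name ("_outside") is now threaded
// through alongside the existing group attribute ("_encapsulate").
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph_out;
Status s = EncapsulateSubgraphsInFunctions(
    "_encapsulate", "_outside", graph_in,
    /*rewrite_subgraph_fn=*/{},
    /*parallel_checking=*/false,
    /*reuse_existing_functions=*/false, &graph_out, &library);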

View File

@@ -36,7 +36,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
if (diff) {
*diff = strings::StrCat("Definition mismatch for function ",
a.signature().name(), ", expected:\n",
-                        a.DebugString());
+                        a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
@@ -82,6 +82,24 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
<< diff << "\nActual: " << actual.DebugString(); \
} while (false)
// TODO(misard): remove these fake registrations once there are real Ops to be
// compiled.
REGISTER_OP("_XlaSendToHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaRecvFromHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaSendFromHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaRecvAtHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("InputTest").Output("o: float"); REGISTER_OP("InputTest").Output("o: float");
REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
@ -98,10 +116,32 @@ REGISTER_OP("AddNLikeTest")
.SetIsCommutative() .SetIsCommutative()
.SetIsAggregate(); .SetIsAggregate();
Node* NoOp(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("NoOp", opts);
}
Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTest", opts);
}
Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
"_XlaRecvAtHost", opts.op_registry());
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
}
Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
"_XlaSendFromHost", opts.op_registry());
node_builder.Input(inputs);
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
}
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return ops::UnaryOp("UnaryTest", std::move(a), opts);
}
@@ -145,7 +185,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
if (!s.ok()) return s;
std::unique_ptr<Graph> graph_out;
-  s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
+  s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph,
/*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/false,
/*reuse_existing_functions=*/false,
@@ -178,6 +218,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) {
FunctionDefLibrary library_out = library_in;
TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
+  // If there are no marked nodes, funcification should be a no-op.
TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
}
@@ -230,7 +271,6 @@ TEST(EncapsulateSubgraphsTest, OneFunction) {
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
-  // If there are no marked nodes, funcification should be a no-op.
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
@@ -342,9 +382,9 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
-    "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
-    /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph,
-    &library));
+    "_cluster", "_outside", graph_before_encapsulation,
+    /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false,
+    /*reuse_existing_functions=*/false, &graph, &library));
std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
@@ -374,9 +414,9 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
-    "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
-    /*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph,
-    &library));
+    "_cluster", "_outside", graph_before_encapsulation,
+    /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true,
+    /*reuse_existing_functions=*/false, &graph, &library));
std::vector<string> expected_nodes = {
"add1", "add2", "cluster1", "cluster1_parallel_check/_0",
@@ -432,7 +472,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
-    "_encapsulate", graph_before,
+    "_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
@@ -477,7 +517,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
-    "_encapsulate", graph_before,
+    "_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
@@ -502,5 +542,678 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
EXPECT_EQ(1, guaranteed_consts);
}
// Test with one function to transform and one outside_compilation cluster.
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
// Give nodes 'c' and 'd' names that collide after lowercasing.
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d = Binary(b, c,
b1.opts().WithName("c").WithControlInput(c).WithAttr(
"_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"C:o:0", "c:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"c"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts().WithName("E").WithControlInputs({recv, b}));
Node* send = SendFromHost({e}, {DT_FLOAT},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* s = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one function to transform and two outside_compilation clusters.
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g = Binary(e, f,
b1.opts()
.WithName("G")
.WithControlInputs({e, f})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* h = Binary(d, e,
b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
Binary(g, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
{{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O2_send"},
"_XlaSendToHost",
{"D:o:0", "F:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"F"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"C:o:0", "D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"D"}},
{{"outside_compilation_O2_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O2_send"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"i_0_retval", "I:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv1 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
Node* send1 = SendFromHost({e}, {DT_FLOAT},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* recv2 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O2_recv"));
Node* g = Binary(e, ops::NodeOut(recv2, 1),
b2.opts().WithName("G").WithControlInputs({recv2, e}));
Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
Node* send2 = SendFromHost(
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
Node* s = NoOp(b2.opts()
.WithName("F1_sequencer")
.WithControlInputs({recv1, send1, recv2, send2}));
Binary(g, call, b2.opts().WithName("J").WithControlInput(s));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two functions to transform, each with one outside_compilation
// cluster.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g = Binary(e, f,
b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
"_encapsulate", "F2"));
Node* h = Binary(d, g,
b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
Node* i =
Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
Binary(g, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"f_0_retval:float", "d_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"C:o:0", "D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"D"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"e_0_arg:float", "f_0_arg:float"},
{"g_0_retval:float", "i_0_retval:float"}, {},
{
{{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
{"f_0_arg", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"G:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
Node* send1 = SendFromHost({e}, {DT_FLOAT},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
Node* recv2 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
Node* h = Binary(ops::NodeOut(call1, 1), recv2,
b2.opts().WithName("H").WithControlInput(s1));
Node* send2 = SendFromHost(
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(e).Input(call1);
Node* call2 = b2.opts()
.WithControlInputs({s1, e, call1})
.FinalizeBuilder(&node_builder2);
Node* s2 = NoOp(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}));
Binary(call2, ops::NodeOut(call2, 1),
b2.opts().WithName("J").WithControlInput(s2));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no inputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f =
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Unary(f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* e = Unary(a, b2.opts().WithName("E"));
Node* send1 = SendFromHost(
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1));
Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no data inputs but has a
// control input from the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithControlInput(d)
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f =
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Unary(f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({})}},
{"D"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 =
RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
Node* send1 = SendFromHost(
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no outputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1));
Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no data outputs but has a
// control output to the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
Node* send1 = SendFromHost({}, {},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no outputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* e = Unary(a, b2.opts().WithName("E"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(e, call1, b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
} // namespace
} // namespace tensorflow

View File

@@ -41,6 +41,7 @@ limitations under the License.
namespace tensorflow {
const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
namespace {

View File

@@ -28,6 +28,10 @@ namespace tensorflow {
// encapsulate subgraphs pass.
extern const char* const kXlaClusterAttr;
// The attribute that marks nodes in a cluster to be placed outside the xla
// compilation by the encapsulate subgraphs pass.
extern const char* const kXlaOutsideCompilationAttr;
// Pass that marks a subset of operators in the graph with attribute
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
class MarkForCompilationPass : public GraphOptimizationPass {

View File

@@ -270,8 +270,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return false;
}
-/* static */ tensorflow::gtl::ArraySlice<const int64>
-LayoutUtil::PaddedDimensions(const Shape& shape) {
+/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
+    const Shape& shape) {
CHECK(IsDense(shape));
return AsInt64Slice(shape.layout().padded_dimensions());
}

View File

@@ -96,7 +96,7 @@ class LayoutUtil {
// Returns the padded_dimensions array for the given Shape. Requires that the
// shape is an array and has a dense layout.
-  static tensorflow::gtl::ArraySlice<const int64> PaddedDimensions(
+  static tensorflow::gtl::ArraySlice<int64> PaddedDimensions(
const Shape& shape);
// Returns the given index of the padded_dimensions array for the given Shape.

View File

@@ -341,7 +341,7 @@ class Literal {
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
-  static std::unique_ptr<Literal> CreateFullWithMonotonicDim0MajorLayout(
+  static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
// Creates a new literal from an array. The variants not ending with
@@ -1233,10 +1233,9 @@ void Literal::PopulateWithValue(NativeT value,
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-Literal::CreateFullWithMonotonicDim0MajorLayout(
+/* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
-  Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+  Shape this_shape = ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
auto literal = MakeUnique<Literal>();
*literal->mutable_shape() = this_shape;

View File

@@ -46,6 +46,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",

View File

@@ -166,6 +166,18 @@ ComputationDataHandle LocalComputationBuilder::Dot(
return builder_.Dot(lhs, rhs);
}
ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
dimension_numbers);
}
ComputationDataHandle LocalComputationBuilder::ConvertElementType(
const ComputationDataHandle& operand, PrimitiveType new_element_type) {
return builder_.ConvertElementType(operand, new_element_type);
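To illustrate the new binding point, a hypothetical C++ call to the wrapper added above might look like the following; `builder`, `lhs`, `rhs`, and `dimension_numbers` are assumed to be constructed elsewhere, and the parameter values are illustrative only.

// Sketch: a 2D convolution with stride 1, one element of explicit padding per
// spatial edge, and no lhs (input) or rhs (kernel) dilation.
ComputationDataHandle conv = builder.ConvGeneralDilated(
    lhs, rhs,
    /*window_strides=*/{1, 1},
    /*padding=*/{{1, 1}, {1, 1}},
    /*lhs_dilation=*/{1, 1},
    /*rhs_dilation=*/{1, 1},
    dimension_numbers);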

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@@ -113,6 +114,14 @@ class LocalComputationBuilder {
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);
ComputationDataHandle ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
PrimitiveType new_element_type);

View File

@@ -22,18 +22,19 @@ limitations under the License.
//
//  C++                                  Python
// -------------------------------------+---------------------------------------
//  ComputationDataHandle            <-> long
//  ComputationDataHandle            <-> int
//  ArraySlice<int64>                <-  sequence of long
//  ArraySlice<int64>                <-  sequence of int
//  ArraySlice<ComputationDataHandle> <- sequence of long
//  ArraySlice<ComputationDataHandle> <- sequence of int
//  Literal                          <-> (nested tuple of) numpy ndarray
//  std::vector<Literal>             <-  sequence of (nested tuple of) ndarray
//  Shape                            <-> pair holding (dtype, dimensions)
//  std::vector<Shape>               <-  sequence of shape information pairs
//  PrimitiveType                    <-  int
//  ArraySlice<pair<int64, int64>>   <-  sequence of int pairs
// ConvolutionDimensionNumbers proto <- corresponding Python proto
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally. Also,
// direction, or whether it is maintained bidirectionally.
// "long" and "int" denote the Python types so named, not C.
//
// The Python objects corresponding to C++ Literals have the type:
//
@@ -113,6 +114,27 @@ limitations under the License.
using namespace xla;
using namespace xla::swig;
namespace xla {
namespace swig {
bool GetIntAttr(PyObject* o, const char* field, int64* result) {
PyObject* fo = PyObject_GetAttrString(o, field);
if (!fo) {
return false;
}
const int64 value = numpy::PyIntOrPyLongToLong(fo);
if (value == -1 && PyErr_Occurred()) {
Py_DECREF(fo);
return false;
}
Py_DECREF(fo);
*result = value;
return true;
}
}
}
%}
// Required to use PyArray_* functions.
@@ -278,6 +300,189 @@ tensorflow::ImportNumpy();
$1 = static_cast<PrimitiveType>(value);
}
// ArraySlice<pair<int64, int64>>
%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
(std::vector<std::pair<int64, int64> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
temps.reserve(size);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
if (!o) {
return NULL;
}
PyObject* first = PyTuple_GetItem(o, 0);
if (!first) {
Py_DECREF(o);
return NULL;
}
PyObject* first_pyint = numpy::PyNumberToPyInt(first);
if (!first_pyint) {
PyErr_SetString(
PyExc_TypeError,
"First pair item cannot be converted to int");
Py_DECREF(o);
return NULL;
}
PyObject* second = PyTuple_GetItem(o, 1);
if (!second) {
Py_DECREF(o);
Py_DECREF(first_pyint);
return NULL;
}
PyObject* second_pyint = numpy::PyNumberToPyInt(second);
if (!second_pyint) {
PyErr_SetString(
PyExc_TypeError,
"Second pair item cannot be converted to int");
Py_DECREF(o);
Py_DECREF(first_pyint);
return NULL;
}
const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint);
if (first_value == -1 && PyErr_Occurred()) {
Py_DECREF(o);
Py_DECREF(first_pyint);
Py_DECREF(second_pyint);
return NULL;
}
const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint);
if (second_value == -1 && PyErr_Occurred()) {
Py_DECREF(o);
Py_DECREF(first_pyint);
Py_DECREF(second_pyint);
return NULL;
}
temps.push_back(std::make_pair(first_value, second_value));
Py_DECREF(o);
}
$1 = temps;
}
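For reference, the Python-side value this typemap consumes is just a sequence of 2-tuples of integers, one (low, high) pair per spatial dimension; a minimal sketch with illustrative values:

```python
# Illustrative only: a value the typemap above converts into
# tensorflow::gtl::ArraySlice<std::pair<int64, int64>>.
padding = [(1, 0), (0, 1)]  # one (low, high) pair per spatial dimension
assert all(len(pair) == 2 for pair in padding)
```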
// ConvolutionDimensionNumbers
%typemap(in) const ConvolutionDimensionNumbers&
(ConvolutionDimensionNumbers dimension_numbers) {
int64 value;
if (!GetIntAttr($input, "input_batch_dimension", &value)) {
return NULL;
}
dimension_numbers.set_input_batch_dimension(value);
if (!GetIntAttr($input, "input_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_input_feature_dimension(value);
if (!GetIntAttr($input, "output_batch_dimension", &value)) {
return NULL;
}
dimension_numbers.set_output_batch_dimension(value);
if (!GetIntAttr($input, "output_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_output_feature_dimension(value);
if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_kernel_output_feature_dimension(value);
if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_kernel_input_feature_dimension(value);
PyObject* o;
int length;
o = PyObject_GetAttrString($input, "input_spatial_dimensions");
if (!o) {
return NULL;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
return NULL;
}
dimension_numbers.add_input_spatial_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(o);
o = PyObject_GetAttrString($input, "kernel_spatial_dimensions");
if (!o) {
return NULL;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
return NULL;
}
dimension_numbers.add_kernel_spatial_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(o);
o = PyObject_GetAttrString($input, "output_spatial_dimensions");
if (!o) {
return NULL;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
return NULL;
}
dimension_numbers.add_output_spatial_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(o);
$1 = &dimension_numbers;
}
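The typemap above reads plain attributes off whatever object it is handed, so any Python object exposing these fields converts; the ConvolutionDimensionNumbers proto from xla_data_pb2 is the intended one. A minimal sketch, assuming that proto module is importable from this build:

```python
# Each scalar field below is read via GetIntAttr; each repeated field is
# iterated with PySequence_GetItem. The dimension values are illustrative.
from tensorflow.compiler.xla import xla_data_pb2

dnums = xla_data_pb2.ConvolutionDimensionNumbers()
dnums.input_batch_dimension = 0
dnums.input_feature_dimension = 1
dnums.output_batch_dimension = 0
dnums.output_feature_dimension = 1
dnums.kernel_output_feature_dimension = 0
dnums.kernel_input_feature_dimension = 1
dnums.input_spatial_dimensions.extend([2, 3])
dnums.kernel_spatial_dimensions.extend([2, 3])
dnums.output_spatial_dimensions.extend([2, 3])
```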
%ignoreall
%unignore xla;
%unignore xla::swig;
@@ -314,6 +519,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
%unignore xla::swig::LocalComputationBuilder::Mul;

View File

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum # pylint: disable=g-bad-import-order
import itertools
import numpy as np
@@ -25,6 +26,12 @@ import numpy as np
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.python import pywrap_xla as c_api
class PaddingType(enum.Enum):
VALID = 1
SAME = 2
_UNARY_OPS = [
'Not',
'Abs',
@@ -564,6 +571,79 @@ class ComputationBuilder(object):
return _wrap_data_handle(
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
def Conv(self, lhs, rhs, window_strides, padding):
"""Enqueues a Conv operation onto the computation.
Args:
lhs: ComputationDataHandle for the rank N+2 array of inputs.
rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
Returns:
A ComputationDataHandle representing the Conv operation.
"""
if padding == PaddingType.SAME:
lhs_dims = self.GetShape(lhs).dimensions()
rhs_dims = self.GetShape(rhs).dimensions()
in_shape, filter_shape = lhs_dims[2:], rhs_dims[2:]
out_shape = np.ceil(np.true_divide(in_shape, window_strides)).astype(int)
pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0)
for out_size, stride, filter_size, in_size
in zip(out_shape, window_strides, filter_shape, in_shape)]
pads = [(pad_size // 2, pad_size - pad_size // 2)
for pad_size in pad_sizes]
else:
pads = [(0, 0)] * len(window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return _wrap_data_handle(
self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
_unwrap_data_handle(rhs),
window_strides,
pads,
(),
(),
dimension_numbers))
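A standalone numpy check of the SAME-padding arithmetic used in Conv() above (the shapes are illustrative): the output size is ceil(input / stride), and the total padding needed to reach it is split between the low and high sides.

```python
import numpy as np

in_shape, window_strides, filter_shape = [3, 4], [1, 1], [1, 2]
out_shape = np.ceil(np.true_divide(in_shape, window_strides)).astype(int)
pad_sizes = [max((out - 1) * stride + filt - inp, 0)
             for out, stride, filt, inp
             in zip(out_shape, window_strides, filter_shape, in_shape)]
pads = [(p // 2, p - p // 2) for p in pad_sizes]
assert pads == [(0, 0), (0, 1)]  # only the width dimension needs padding
```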
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
lhs: ComputationDataHandle for the rank N+2 array of inputs.
rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
window_strides: length-N array-like of kernel strides.
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
Returns:
A ComputationDataHandle representing the added ConvWithGeneralPadding op.
"""
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return _wrap_data_handle(
self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
_unwrap_data_handle(rhs),
window_strides,
padding,
lhs_dilation,
rhs_dilation,
dimension_numbers))
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
nd = num_spatial_dims
dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()
dimension_numbers.input_batch_dimension = 0
dimension_numbers.input_feature_dimension = 1
dimension_numbers.output_batch_dimension = 0
dimension_numbers.output_feature_dimension = 1
dimension_numbers.kernel_output_feature_dimension = 0
dimension_numbers.kernel_input_feature_dimension = 1
dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
return dimension_numbers
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.

View File

@@ -386,6 +386,46 @@ class SingleOpTest(LocalComputationTest):
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testConvF32Same(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(c.Constant(lhs), c.Constant(rhs),
[1, 1], xla_client.PaddingType.SAME)
result = np.array([[[[640., 700., 760., 300.],
[880., 940., 1000., 380.],
[1120., 1180., 1240., 460.]]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvF32Valid(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(c.Constant(lhs), c.Constant(rhs),
[2, 1], xla_client.PaddingType.VALID)
result = np.array([[[[640., 700., 760.],
[1120., 1180., 1240.]]]])
self._ExecuteAndCompareClose(c, expected=result)
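A plain-numpy cross-check (not part of the test) of where the first VALID output value comes from: each output element contracts the two input features against the 1x2 kernel window.

```python
import numpy as np

a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs, rhs = a(1, 2, 3, 4), a(1, 2, 1, 2) * 10
val = sum(lhs[0, c, 0, kx] * rhs[0, c, 0, kx]
          for c in range(2) for kx in range(2))
assert val == 640.0  # matches result[0, 0, 0, 0] above
```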
def testConvWithGeneralPaddingF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=result)
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])

View File

@@ -1182,6 +1182,11 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
}
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
pad, HloInstruction::CreateBroadcast(pad->shape(),
pad->mutable_operand(1), {}));
}
// Eliminate nop pads (padding all zero), and replace a pad with negative
// padding with a pad with non-negative padding followed by a slice.
bool all_zero = true;
@@ -1624,6 +1629,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* reduce_window) {
if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) {
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateBroadcast(reduce_window->shape(),
reduce_window->mutable_operand(1), {}));
}
auto operand = reduce_window->mutable_operand(0);
const Window& window = reduce_window->window();
auto function = reduce_window->to_apply();
@@ -1694,7 +1705,6 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
auto operand = transpose->mutable_operand(0);
if (std::is_sorted(transpose->dimensions().begin(),
transpose->dimensions().end())) {
VLOG(10) << "deleting no-op transpose";
@@ -1721,6 +1731,18 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction* convolution) {
auto lhs = convolution->mutable_operand(0);
auto rhs = convolution->mutable_operand(1);
if (ShapeUtil::HasZeroElements(lhs->shape()) ||
ShapeUtil::HasZeroElements(rhs->shape())) {
return ReplaceWithNewInstruction(
convolution,
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(convolution->shape().element_type(), {}),
computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))),
{}));
}
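The intuition behind this rewrite, sketched with numpy (illustrative only, not XLA semantics verbatim): contracting over a zero-sized dimension is an empty sum, so every element of the result is already 0, i.e. a broadcast of a zero scalar.

```python
import numpy as np

a = np.ones((4, 0), dtype=np.float32)  # zero-sized contracted dimension
b = np.ones((0, 5), dtype=np.float32)
out = a @ b                            # each output element is an empty sum
assert out.shape == (4, 5) and not out.any()
```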
const auto& window = convolution->window();
if (!enable_conv_simplification_) {
return Status::OK();
@@ -1813,17 +1835,14 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// We already checked feature_dimension is most minor, so data in input_shape
// and row-major {conv_width,input_channels} are bitwise identical.
const Shape new_input_shape =
    ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
    input_shape.element_type(), {conv_width, input_channels});
// We already checked input_feature_dimension is more major than
// output_feature_dimension, so data in filter_shape and row-major
// {input_channels,output_channels} are bitwise identical.
const Shape new_filter_shape =
    ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
    filter_shape.element_type(), {input_channels, output_channels});
const Shape dot_output_shape =
    ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
    convolution_shape.element_type(), {conv_width, output_channels});
// We cannot insert bitcasts if the layouts will not be compatible.

View File

@@ -816,6 +816,120 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
1);
}
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));
HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.add_input_spatial_dimensions(1);
dnums.set_input_feature_dimension(2);
dnums.set_output_batch_dimension(0);
dnums.add_output_spatial_dimensions(1);
dnums.set_output_feature_dimension(2);
dnums.add_kernel_spatial_dimensions(0);
dnums.set_kernel_input_feature_dimension(1);
dnums.set_kernel_output_feature_dimension(2);
Window window;
WindowDimension* dim = window.add_dimensions();
dim->set_size(3);
dim->set_padding_low(0);
dim->set_padding_high(0);
dim->set_stride(1);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
dim->set_window_reversal(false);
// Create the module and the convolution.
std::unique_ptr<HloModule> module = CreateNewModule();
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
module->AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Convolution(lhs, rhs));
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
Window window;
for (int64 i = 0; i < 2; ++i) {
WindowDimension* dim = window.add_dimensions();
dim->set_size(1);
dim->set_padding_low(1);
dim->set_padding_high(1);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
}
// Create add computation.
std::unique_ptr<HloModule> module = CreateNewModule();
HloComputation* add_computation = nullptr;
{
HloComputation::Builder builder(TestName() + ".add");
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
HloInstruction* p0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "p0"));
HloInstruction* p1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
add_computation = module->AddEmbeddedComputation(builder.Build());
}
builder.AddInstruction(HloInstruction::CreateReduceWindow(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
window, add_computation));
module->AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::ReduceWindow(param, op::Constant()));
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
PaddingConfig padding;
for (int i = 0; i < 2; ++i) {
PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions();
dimension->set_edge_padding_low(1);
dimension->set_edge_padding_high(1);
dimension->set_interior_padding(0);
}
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
padding));
std::unique_ptr<HloModule> module = CreateNewModule();
module->AddEntryComputation(builder.Build());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Pad(param, op::Constant()));
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -1309,7 +1423,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
HloComputation::Builder builder(TestName());
HloInstruction* param0 =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(F32, {2, 2, 2}),
0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}),
"param0"));
HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary(

View File

@@ -28,8 +28,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace se = ::perftools::gputools;
namespace xla {
StatusOr<GlobalDataHandle> AllocationTracker::Register(

View File

@@ -96,6 +96,26 @@ class ConvolutionThunk : public Thunk {
return !best_algorithm_.has_value();
}
// Returns true if scratch memory is needed to execute the thunk, that is,
// either the best algorithm hasn't been chosen or the best algorithm is not
// the same as the no-scratch algorithm. This is because the execution
// of the thunk is asynchronous, and the scratch allocator goes out of
// scope before the thunk finishes execution. Returning true tells the stream
// executor to make future thunks wait for this thunk to avoid reusing the
// deallocated scratch memory until this thunk is done with it.
bool ShouldBlockFutureThunks() {
if (!best_algorithm_.has_value()) {
return true;
}
const perftools::gputools::dnn::AlgorithmDesc& best_alg =
best_algorithm_->algorithm();
const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg =
best_algorithm_->algorithm_no_scratch();
return (!best_alg.is_default() || !no_scratch_best_alg.is_default() ||
!(best_alg == no_scratch_best_alg));
}
private:
tensorflow::Status ConvolveWithTune(
const perftools::gputools::dnn::BatchDescriptor& input_descriptor,

View File

@@ -155,6 +155,7 @@ Status GpuExecutable::ExecuteThunks(
run_options->BorrowStream(main_stream->parent()->device_ordinal()));
}
std::map<int32, const Thunk*> last_blocking_thunk_for_stream;
std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
TF_RETURN_IF_ERROR(thunk->Initialize(*this));
@@ -167,10 +168,17 @@
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
}
if (last_blocking_thunk_for_stream.count(stream_no)) {
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event,
last_blocking_thunk_for_stream[stream_no])
.get());
}
// If this thunk requests it, wait for all currently-executing thunks to
// finish. This is useful e.g. if the thunk is about to perform autotuning.
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
last_blocking_thunk_for_stream.clear();
}
profiler.StartOperation();
@@ -178,11 +186,14 @@
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
if (thunk_schedule_->Depended(thunk)) {
if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) {
auto finish_event = MakeUnique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
if (thunk->ShouldBlockFutureThunks()) {
last_blocking_thunk_for_stream[stream_no] = thunk;
}
}
profiler.FinishOperation(thunk->hlo_instruction());
}

View File

@@ -419,7 +419,7 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
(segs.size() == i ? shape.dimensions().size() : segs[i]),
1, std::multiplies<int64>()));
}
return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(),
return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
dimensions);
}
@@ -442,11 +442,13 @@ std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
}
}
auto segs = ConsecutiveSegments(perm);
Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a);
Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b);
Shape norm_a =
    ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
Shape norm_b =
    ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
if (3 == segs.size() && 0 == perm[0]) {
Shape reduced_a = MergeDimensions(segs, norm_a);
Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
b.element_type(),
Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
return std::make_tuple(true, reduced_a, reduced_b);
@@ -460,10 +462,11 @@ std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
return 3 == b.dimensions().size() &&
ShapeUtil::Compatible(
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a),
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
ShapeUtil::PermuteDimensions(
{0, 2, 1},
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b)));
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
    b)));
}
// Emits a tiled 0-2-1 transpose, assuming both input and output are laid out from
@@ -495,9 +498,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
Shape input_shape =
    ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape());
Shape input_shape =
    ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
        input.GetShape());
Shape output_shape =
    ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape());
Shape output_shape =
    ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
        output.GetShape());
input = input.CastToShape(input_shape, builder);
output = output.CastToShape(output_shape, builder);
@@ -615,7 +620,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
builder))),
builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
ShapeUtil::MakeShapeWithDescendingLayout(
PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
builder);
const llvm_ir::IrArray::Index input_tile_origin = ({
@@ -811,14 +816,15 @@ Status IrEmitterUnnested::EmitColumnReduction(
// input_shape to normalized_input_shape and a reshape from
// normalized_input_shape to input_matrix_shape.
const Shape normalized_input_shape =
    ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
const Shape normalized_input_shape =
    ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
        input_shape);
auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
const std::vector<int64> transpose_dimension_mapping(
input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
const Shape input_matrix_shape =
    ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
        input_shape.element_type(), {height, width});
const Shape input_matrix_shape =
    ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
                                             {height, width});
const llvm_ir::IrArray::Index input_matrix_index(
{y, x}, input_matrix_shape, &ir_builder_);
const llvm_ir::IrArray::Index input_index =
@@ -1054,13 +1060,14 @@ Status IrEmitterUnnested::EmitRowReduction(
// from input_shape to normalized_input_shape and a reshape from
// normalized_input_shape to input_3d_tensor_shape.
const Shape normalized_input_shape =
    ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
const Shape normalized_input_shape =
    ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
        input_shape);
auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
const std::vector<int64> transpose_dimension_mapping(
input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
const Shape input_3d_tensor_shape =
    ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
        input_shape.element_type(), {depth, height, width});
const Shape input_3d_tensor_shape =
    ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
                                             {depth, height, width});
const llvm_ir::IrArray::Index input_3d_tensor_index(
{z, y, x}, input_3d_tensor_shape, &ir_builder_);
const llvm_ir::IrArray::Index input_index =

View File

@@ -83,6 +83,16 @@ class Thunk {
return false;
}
// Indicates whether thunks scheduled after this one should wait for this one
// to complete before running. For example, a convolution thunk creates a
// scratch allocator, then kicks off a convolution in cudnn via the stream
// executor. When the stream executor call returns, the scratch allocator goes
// out of scope, and the scratch memory is deallocated. In this case, the
// convolution thunk needs to return true so that future thunks wait for the
// convolution thunk to avoid reusing the deallocated memory until the
// convolution thunk is done with it.
virtual bool ShouldBlockFutureThunks() { return false; }
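A toy Python sketch (not a TF/XLA API) of the scheduling rule this hook enables, mirroring the GpuExecutable::ExecuteThunks change above: a blocking thunk gets a finish event recorded, and the next thunk on the same stream waits on that event before running.

```python
def run(thunks, depended_on, should_block):
    # thunks: ordered (stream_no, name) pairs, as in the thunk schedule.
    finish_events, last_blocking = {}, {}
    for stream, thunk in thunks:
        if stream in last_blocking:
            print(thunk, "waits on", finish_events[last_blocking.pop(stream)])
        if depended_on(thunk) or should_block(thunk):
            finish_events[thunk] = "ev_" + thunk  # record a finish event
            if should_block(thunk):
                last_blocking[stream] = thunk

run([(0, "conv"), (0, "add")], lambda t: False, lambda t: t == "conv")
# -> add waits on ev_conv: the conv's scratch memory is not reused early.
```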
// Execute the kernel for the thunk on the given stream. This method must be
// called after Initialize and can be called multiple times over Thunk's
// lifetime. Stream argument must be non-null.

View File

@@ -1361,7 +1361,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
std::unique_ptr<Literal> arg_literal =
    Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
    Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
@@ -1414,7 +1414,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
std::unique_ptr<Literal> result_literal =
    Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 8.0f);
    Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
LiteralTestUtil::ExpectEqual(*result_literal, *result);
}

View File

@@ -478,7 +478,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
// Add constraints for kCustomCall instruction operands and instructions.
// For now we only support major-first layouts for all inputs and outputs.
Shape result_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
instruction->shape().element_type(),
AsInt64Slice(instruction->shape().dimensions()));
TF_RETURN_IF_ERROR(
@@ -491,7 +491,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
}
Shape row_major_operand_shape =
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
ShapeUtil::MakeShapeWithDescendingLayout(
operand_shape.element_type(),
AsInt64Slice(operand_shape.dimensions()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(

View File

@@ -189,20 +189,21 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
.ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
std::vector<int64> layout(dimensions.size());
std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
return MakeShapeWithLayout(element_type, dimensions, layout);
}
/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
const Shape& shape) {
std::vector<int64> dims(shape.dimensions_size());
for (int i = 0; i < shape.dimensions_size(); ++i) {
dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
}
return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims);
return MakeShapeWithDescendingLayout(shape.element_type(), dims);
}
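A small Python sketch (illustrative, not the XLA API) of what the function above computes: listing the logical dimensions in major-to-minor order yields a shape whose default descending layout preserves the physical order.

```python
def descending_layout_dims(dimensions, minor_to_major):
    # Walk the layout from most-major to most-minor dimension.
    return [dimensions[d] for d in reversed(minor_to_major)]

# f32[B,H,W,C]{0,3,2,1} becomes f32[H,W,C,B]{3,2,1,0}:
assert descending_layout_dims([2, 3, 5, 7], [0, 3, 2, 1]) == [3, 5, 7, 2]
```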
/* static */ void ShapeUtil::PopulateShape(
@@ -1148,9 +1149,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
// as input_shape/output_shape and the dimension-0-major layout. These two
// shapes are used for conversion between logical linear indices and
// multi-dimensional indices.
Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout(
Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout(
Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));
for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) {

View File

@@ -268,14 +268,18 @@ class ShapeUtil {
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major);
// Constructs a new shape with major-first layout.
static Shape MakeShapeWithMonotonicDim0MajorLayout(
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
static Shape MakeShapeWithDescendingLayout(
PrimitiveType element_type,
tensorflow::gtl::ArraySlice<int64> dimensions);
// Returns a new shape with major-first layout that has the same layout of
// elements with a different shape.
static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape);
// Returns a new Shape based on the given Shape with low-dimension-major
// layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
// rearranged so that it has the same in-memory layout as the given shape.
//
// For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
const Shape& shape);
// As MakeShape, but the object to write to is passed in.
static void PopulateShape(PrimitiveType element_type,

View File

@@ -393,7 +393,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
auto shape = ShapeUtil::MakeShape(F32, input_dims);
std::unique_ptr<Literal> arg_literal =
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
@@ -402,7 +402,7 @@
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
std::unique_ptr<Literal> expected =
    Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 9.0f);
    Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
}

View File

@@ -397,6 +397,7 @@ py_test(
srcs = ["scan_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
@@ -491,6 +492,25 @@ py_test(
],
)
py_test(
name = "unique_dataset_op_test",
size = "small",
srcs = ["unique_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/contrib/stateless",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//third_party/py/numpy",
],
)
py_test(
name = "zip_dataset_op_test",
size = "small",

View File

@@ -735,6 +735,20 @@ class BatchDatasetSerializationTest(
lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
num_outputs)
def _build_dataset_dense_to_sparse(self, components):
return dataset_ops.Dataset.from_tensor_slices(components).map(
lambda x: array_ops.fill([x], x)).apply(
batching.dense_to_sparse_batch(4, [12]))
def testDenseToSparseBatchDatasetCore(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
num_outputs = len(components) // 4
self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
lambda: self._build_dataset_dense_to_sparse(diff_comp),
num_outputs)
class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):

View File

@@ -21,6 +21,7 @@ import itertools
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -124,5 +125,18 @@
scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
class ScanDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_dataset(self, num_elements):
return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
def testScanCore(self):
num_output = 5
self.run_core_tests(lambda: self._build_dataset(num_output),
lambda: self._build_dataset(2), num_output)
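The scan function above threads a Fibonacci state pair; a plain-Python sketch of the five outputs run_core_tests expects:

```python
state, outs = [0, 1], []
for _ in range(5):
    state, out = [state[1], state[0] + state[1]], state[1]
    outs.append(out)
assert outs == [1, 1, 2, 3, 5]
```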
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@@ -0,0 +1,96 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import unique
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class UniqueDatasetTest(test.TestCase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
Args:
dtype: The `dtype` of the elements in each test case.
test_cases: A list of pairs of lists. The first component is the test
input that will be passed to the transformation; the second component
is the expected sequence of outputs from the transformation.
"""
# The `current_test_case` will be updated when we loop over `test_cases`
# below; declare it here so that the generator can capture it once.
current_test_case = []
dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
dtype).apply(unique.unique())
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with self.test_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)
for element in expected:
if dtype == dtypes.string:
element = compat.as_bytes(element)
self.assertAllEqual(element, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testSimpleInt(self):
for dtype in [dtypes.int32, dtypes.int64]:
self._testSimpleHelper(dtype, [
([], []),
([1], [1]),
([1, 1, 1, 1, 1, 1, 1], [1]),
([1, 2, 3, 4], [1, 2, 3, 4]),
([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
])
def testSimpleString(self):
self._testSimpleHelper(dtypes.string, [
([], []),
(["hello"], ["hello"]),
(["hello", "hello", "hello"], ["hello"]),
(["hello", "world"], ["hello", "world"]),
(["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
])
class UniqueSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def testUnique(self):
def build_dataset(num_elements, unique_elem_range):
return dataset_ops.Dataset.range(num_elements).map(
lambda x: x % unique_elem_range).apply(unique.unique())
self.run_core_tests(lambda: build_dataset(200, 100),
lambda: build_dataset(40, 100), 100)
if __name__ == "__main__":
test.main()

View File

@@ -105,6 +105,7 @@ py_library(
"resampling.py",
"scan_ops.py",
"stats_ops.py",
"unique.py",
],
srcs_version = "PY2AND3",
deps = [

View File

@@ -0,0 +1,82 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unique element dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops
def unique():
"""Creates a `Dataset` from another `Dataset`, discarding duplicates.
Use this transformation to produce a dataset that contains one instance of
each unique element in the input. For example:
```python
dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
# Using `unique()` will drop the duplicate elements.
dataset = dataset.apply(tf.contrib.data.unique()) # ==> { 1, 37, 2 }
```
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return UniqueDataset(dataset)
return _apply_fn
class UniqueDataset(dataset_ops.Dataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
super(UniqueDataset, self).__init__()
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
raise TypeError(
"`tf.contrib.data.unique()` only supports inputs with a single "
"`tf.int32`, `tf.int64`, or `tf.string` component.")
def _as_variant_tensor(self):
return gen_dataset_ops.unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
output_types=nest.flatten(
sparse.as_dense_types(self.output_types, self.output_classes)))
@property
def output_classes(self):
return self._input_dataset.output_classes
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_types(self):
return self._input_dataset.output_types
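The element-level semantics of unique() sketched in plain Python (assuming hashable scalar elements; first-occurrence order is preserved, matching the docstring example above):

```python
def unique_elements(xs):
    seen = set()
    for x in xs:
        if x not in seen:
            seen.add(x)
            yield x

assert list(unique_elements([1, 37, 2, 37, 2, 1])) == [1, 37, 2]
```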

View File

@@ -110,6 +110,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/contrib/tpu",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",

View File

@@ -22,11 +22,14 @@ import numpy as np
import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@@ -95,6 +98,18 @@ class SubGraphTest(test.TestCase):
filtered_list = sub_graph.filter_list(input_list)
self.assertEqual(filtered_list, [b])
def testVariableUses(self):
with ops.Graph().as_default():
var = variable_scope.get_variable('var', shape=[10, 10])
resource_var = variable_scope.get_variable(
'resource_var', shape=[10, 10], use_resource=True)
x = array_ops.zeros([3, 10])
z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
z1 = math_ops.matmul(x, resource_var)
sub_graph = utils.SubGraph((z0, z1))
self.assertEqual(2, sub_graph.variable_uses(var))
self.assertEqual(1, sub_graph.variable_uses(resource_var))
class UtilsTest(test.TestCase):
@@ -253,6 +268,25 @@ class UtilsTest(test.TestCase):
np_inv = np.linalg.inv(x + damp * np.eye(size))
self.assertAllClose(sess.run(tf_inv), np_inv)
def testCrossReplicaMean(self):
"""Ensures that cross_replica_mean() executes only when num_shards > 1."""
with ops.Graph().as_default():
with tpu_function.tpu_shard_context(4):
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
self.assertNotEqual(mean, tensor)
with ops.Graph().as_default():
with tpu_function.tpu_shard_context(1):
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
self.assertEqual(mean, tensor)
with ops.Graph().as_default():
with self.assertRaises(ValueError): # Outside of TPU context.
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
if __name__ == '__main__':
test.main()

View File

@@ -196,6 +196,7 @@ py_library(
srcs = ["utils.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/tpu",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",

View File

@@ -267,6 +267,10 @@ class FisherFactor(object):
new_cov = math_ops.add_n(
tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
# Synchronize value across all TPU cores.
if utils.on_tpu():
new_cov = utils.cross_replica_mean(new_cov)
return moving_averages.assign_moving_average(
self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)

View File

@@ -20,6 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -27,6 +29,8 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
# Method used for inverting matrices.
POSDEF_INV_METHOD = "cholesky"
@@ -226,11 +230,13 @@ class SubGraph(object):
""" """
def __init__(self, outputs): def __init__(self, outputs):
# Set of all ancestor Tensors and Ops of 'outputs'.
self._members = set() self._members = set()
self._recurse_add(outputs) self._recurse_add(outputs)
def _recurse_add(self, nodes): def _recurse_add(self, nodes):
"""Recursively adds all of nodes' ancestors."""
for node in nodes: for node in nodes:
if node in self._members: if node in self._members:
continue continue
@ -246,8 +252,25 @@ class SubGraph(object):
return node in self._members return node in self._members
def variable_uses(self, var): def variable_uses(self, var):
"""Computes number of times a variable is used.""" """Computes number of times a variable is used.
return len(self._members.intersection(set(var.value().consumers())))
Args:
var: Variable or ResourceVariable instance.
Returns:
Number of times a variable is used within this subgraph.
Raises:
ValueError: If 'var' is not a variable type.
"""
if isinstance(var, resource_variable_ops.ResourceVariable):
var = var.handle
elif isinstance(var, variables.Variable):
var = var.value()
else:
raise ValueError("%s does not appear to be a variable." % str(var))
return len(self._members.intersection(set(var.consumers())))
def filter_list(self, node_list): def filter_list(self, node_list):
"""Filters 'node_list' to nodes in this subgraph.""" """Filters 'node_list' to nodes in this subgraph."""
@ -292,5 +315,34 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
return dysdx return dysdx
def on_tpu():
"""Returns True when building a TPU computation."""
return tpu_function.get_tpu_context().number_of_shards is not None
def cross_replica_mean(tensor, name=None):
"""Takes mean value of a Tensor across all TPU cores.
Args:
tensor: Tensor to be synchronized.
name: None or string. Name of Op.
Returns:
Average of Tensor across all TPU cores.
Raises:
ValueError: If called outside of TPU context.
"""
with ops.name_scope(name, "cross_replica_mean", [tensor]):
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
raise ValueError(
"Cannot take cross_replica_mean() outside of TPU Context.")
if num_shards == 1:
return tensor
return tpu_ops.cross_replica_sum(tensor / num_shards)
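A hypothetical usage sketch of the two helpers above, mirroring the unit test earlier in this commit: with more than one shard, cross_replica_mean() divides by the shard count and applies a cross-replica sum; with a single shard the input tensor is returned unchanged.

with tpu_function.tpu_shard_context(4):
  assert on_tpu()                     # a shard context is active
  grad = array_ops.ones([8])          # per-core value
  synced = cross_replica_mean(grad)   # == cross_replica_sum(grad / 4)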
# TODO(b/69623235): Add a function for finding tensors that share gradients # TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations. # to eliminate redundant fisher factor computations.
View File
@ -179,6 +179,10 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
ConcatenateTensorBuffers<ArrayDataType::kInt64>( ConcatenateTensorBuffers<ArrayDataType::kInt64>(
input_arrays, concatenation_axis, &concatenated_array); input_arrays, concatenation_axis, &concatenated_array);
break; break;
case ArrayDataType::kString:
ConcatenateTensorBuffers<ArrayDataType::kString>(
input_arrays, concatenation_axis, &concatenated_array);
break;
default: default:
LOG(FATAL) << "ArrayDataType not supported"; LOG(FATAL) << "ArrayDataType not supported";
} }
View File
@ -52,6 +52,7 @@ using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT; using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32; using tensorflow::DT_INT32;
using tensorflow::DT_INT64; using tensorflow::DT_INT64;
using tensorflow::DT_STRING;
using tensorflow::DT_UINT8; using tensorflow::DT_UINT8;
using tensorflow::GraphDef; using tensorflow::GraphDef;
using tensorflow::NodeDef; using tensorflow::NodeDef;
@ -135,6 +136,8 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kInt32; return ArrayDataType::kInt32;
else if (dtype == DT_INT64) else if (dtype == DT_INT64)
return ArrayDataType::kInt64; return ArrayDataType::kInt64;
else if (dtype == DT_STRING)
return ArrayDataType::kString;
else else
LOG(INFO) << "Unsupported data type in placehoder op: " << dtype; LOG(INFO) << "Unsupported data type in placehoder op: " << dtype;
return ArrayDataType::kNone; return ArrayDataType::kNone;
@ -236,6 +239,27 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
} }
} }
void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
ImportShape(input_shape.dim(), output_array->mutable_shape());
int input_flat_size = 1;
for (int k = 0; k < input_shape.dim_size(); k++) {
input_flat_size *= input_shape.dim(k).size();
}
auto& output_string_data =
output_array->GetMutableBuffer<ArrayDataType::kString>().data;
output_string_data.resize(input_flat_size);
if (input_flat_size != input_tensor.string_val_size()) {
LOG(FATAL) << "Input_content string_val doesn't have the right "
"dimensions for this string tensor.";
}
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
}
// Count the number of inputs of a given node. If // Count the number of inputs of a given node. If
// `tf_import_flags.drop_control_dependency` is true, count the number of // `tf_import_flags.drop_control_dependency` is true, count the number of
// non-control-dependency inputs. // non-control-dependency inputs.
@ -261,23 +285,30 @@ void ConvertConstOperator(const NodeDef& node,
const auto dtype = GetDataTypeAttr(node, "dtype"); const auto dtype = GetDataTypeAttr(node, "dtype");
auto& array = model->GetOrCreateArray(node.name()); auto& array = model->GetOrCreateArray(node.name());
array.data_type = dtype == DT_FLOAT switch (dtype) {
? ArrayDataType::kFloat case DT_FLOAT:
: dtype == DT_INT32 array.data_type = ArrayDataType::kFloat;
? ArrayDataType::kInt32
: dtype == DT_INT64 ? ArrayDataType::kInt64
: ArrayDataType::kNone;
if (dtype == DT_FLOAT) {
ImportFloatArray(tensor, &array); ImportFloatArray(tensor, &array);
} else if (dtype == DT_INT32) { break;
case DT_INT32:
array.data_type = ArrayDataType::kInt32;
ImportInt32Array(tensor, &array); ImportInt32Array(tensor, &array);
} else if (dtype == DT_INT64) { break;
case DT_INT64:
array.data_type = ArrayDataType::kInt64;
ImportInt64Array(tensor, &array); ImportInt64Array(tensor, &array);
} else { break;
// do nothing, silently ignore the Const data. For example, there are consts case DT_STRING:
// of string type. We just make a dummy buffer to indicate that this array array.data_type = ArrayDataType::kString;
// does not rely on external input. ImportStringArray(tensor, &array);
break;
default:
array.data_type = ArrayDataType::kNone;
// do nothing, silently ignore the Const data.
// We just make a dummy buffer to indicate that
// this array does not rely on external input.
array.GetMutableBuffer<ArrayDataType::kNone>(); array.GetMutableBuffer<ArrayDataType::kNone>();
break;
} }
} }
@ -1191,7 +1222,7 @@ void ConvertGatherOperator(const NodeDef& node,
CHECK_EQ(node.op(), "Gather"); CHECK_EQ(node.op(), "Gather");
CHECK_EQ(GetInputsCount(node, tf_import_flags), 2); CHECK_EQ(GetInputsCount(node, tf_import_flags), 2);
const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
CHECK(indices_data_type == DT_INT32); CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
auto* op = new GatherOperator; auto* op = new GatherOperator;
op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1)); op->inputs.push_back(node.input(1));
View File
@ -153,7 +153,15 @@ enum class AxesOrder {
// because we'll be dropping the array anyway (e.g. some exotic array types // because we'll be dropping the array anyway (e.g. some exotic array types
// may be involved only in debug-only subgraphs that we may not be interested // may be involved only in debug-only subgraphs that we may not be interested
// in actually supporting). // in actually supporting).
enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 }; enum class ArrayDataType {
kNone,
kBool,
kFloat,
kUint8,
kInt32,
kInt64,
kString
};
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
template <ArrayDataType A> template <ArrayDataType A>
@ -182,6 +190,10 @@ template <>
struct DataTypeImpl<ArrayDataType::kInt64> { struct DataTypeImpl<ArrayDataType::kInt64> {
typedef int64 Type; typedef int64 Type;
}; };
template <>
struct DataTypeImpl<ArrayDataType::kString> {
typedef string Type;
};
template <ArrayDataType A> template <ArrayDataType A>
using DataType = typename DataTypeImpl<A>::Type; using DataType = typename DataTypeImpl<A>::Type;
View File
@ -53,6 +53,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
return ::tflite::TensorType_INT32; return ::tflite::TensorType_INT32;
case ArrayDataType::kUint8: case ArrayDataType::kUint8:
return ::tflite::TensorType_UINT8; return ::tflite::TensorType_UINT8;
case ArrayDataType::kString:
return ::tflite::TensorType_STRING;
default: default:
// FLOAT32 is filled for unknown data types. // FLOAT32 is filled for unknown data types.
// TODO(ycling): Implement type inference in TF Lite interpreter. // TODO(ycling): Implement type inference in TF Lite interpreter.
@ -66,6 +68,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
return ArrayDataType::kFloat; return ArrayDataType::kFloat;
case ::tflite::TensorType_INT32: case ::tflite::TensorType_INT32:
return ArrayDataType::kInt32; return ArrayDataType::kInt32;
case ::tflite::TensorType_STRING:
return ArrayDataType::kString;
case ::tflite::TensorType_UINT8: case ::tflite::TensorType_UINT8:
return ArrayDataType::kUint8; return ArrayDataType::kUint8;
default: default:
@ -82,6 +86,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
return CopyBuffer<ArrayDataType::kFloat>(array, builder); return CopyBuffer<ArrayDataType::kFloat>(array, builder);
case ArrayDataType::kInt32: case ArrayDataType::kInt32:
return CopyBuffer<ArrayDataType::kInt32>(array, builder); return CopyBuffer<ArrayDataType::kInt32>(array, builder);
case ArrayDataType::kString:
return CopyBuffer<ArrayDataType::kString>(array, builder);
case ArrayDataType::kUint8: case ArrayDataType::kUint8:
return CopyBuffer<ArrayDataType::kUint8>(array, builder); return CopyBuffer<ArrayDataType::kUint8>(array, builder);
default: default:
@ -99,6 +105,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
return CopyBuffer<ArrayDataType::kFloat>(buffer, array); return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
case ::tflite::TensorType_INT32: case ::tflite::TensorType_INT32:
return CopyBuffer<ArrayDataType::kInt32>(buffer, array); return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
case ::tflite::TensorType_STRING:
return CopyBuffer<ArrayDataType::kString>(buffer, array);
case ::tflite::TensorType_UINT8: case ::tflite::TensorType_UINT8:
return CopyBuffer<ArrayDataType::kUint8>(buffer, array); return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
default: default:
View File
@ -316,6 +316,9 @@ void LogArray(int log_level, const Model& model, const string& name) {
case ArrayDataType::kUint8: case ArrayDataType::kUint8:
VLOG(log_level) << " Data type: kUint8"; VLOG(log_level) << " Data type: kUint8";
break; break;
case ArrayDataType::kString:
VLOG(log_level) << " Data type: kString";
break;
default: default:
VLOG(log_level) << " Data type: other (numerical value: " VLOG(log_level) << " Data type: other (numerical value: "
<< static_cast<int>(array.data_type) << ")"; << static_cast<int>(array.data_type) << ")";
@ -334,6 +337,9 @@ void LogArray(int log_level, const Model& model, const string& name) {
case ArrayDataType::kUint8: case ArrayDataType::kUint8:
VLOG(log_level) << " Final type: kUint8"; VLOG(log_level) << " Final type: kUint8";
break; break;
case ArrayDataType::kString:
VLOG(log_level) << " Final type: kString";
break;
default: default:
VLOG(log_level) << " Final type: other (numerical value: " VLOG(log_level) << " Final type: other (numerical value: "
<< static_cast<int>(array.data_type) << ")"; << static_cast<int>(array.data_type) << ")";
@ -1253,6 +1259,11 @@ int ElementSize(ArrayDataType data_type) {
return 1; return 1;
case ArrayDataType::kInt64: case ArrayDataType::kInt64:
return 8; return 8;
// Usually not a critical limitation, because strings typically appear only
// as inputs and/or outputs.
case ArrayDataType::kString:
LOG(FATAL) << "Transient arrays with strings are not supported yet";
return 0;
default: default:
LOG(FATAL) << "Should not get here."; LOG(FATAL) << "Should not get here.";
return 0; return 0;
View File
@ -82,22 +82,6 @@ py_test(
], ],
) )
py_test(
name = "utils_test",
size = "small",
srcs = ["python/saved_model/utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":saved_model_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:variables",
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:signature_constants",
"//tensorflow/python/saved_model:tag_constants",
],
)
filegroup( filegroup(
name = "all_files", name = "all_files",
srcs = glob( srcs = glob(
View File
@ -0,0 +1,4 @@
op {
graph_op_name: "UniqueDataset"
summary: "Creates a dataset that contains the unique elements of `input_dataset`."
}
View File
@ -35,25 +35,41 @@ bool IsAdd(const NodeDef& node) {
bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; } bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
bool IsAnyDiv(const NodeDef& node) { bool IsAnyDiv(const NodeDef& node) {
return node.op() == "RealDiv" || node.op() == "Div" || return node.op() == "RealDiv" || node.op() == "Div" ||
node.op() == "FloorDiv" || node.op() == "TruncateDiv"; node.op() == "FloorDiv" || node.op() == "TruncateDiv";
} }
bool IsApproximateEqual(const NodeDef& node) {
return node.op() == "ApproximateEqual";
}
bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; } bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; } bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
bool IsBiasAdd(const NodeDef& node) { bool IsBiasAdd(const NodeDef& node) {
return node.op() == "BiasAdd" || node.op() == "BiasAddV1"; return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
} }
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; } bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; } bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
bool IsConstant(const NodeDef& node) { return node.op() == "Const"; } bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) { bool IsConv2DBackpropFilter(const NodeDef& node) {
@ -92,39 +108,77 @@ bool IsEnter(const NodeDef& node) {
return op == "Enter" || op == "RefEnter"; return op == "Enter" || op == "RefEnter";
} }
bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
bool IsExit(const NodeDef& node) { bool IsExit(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "Exit" || op == "RefExit"; return op == "Exit" || op == "RefExit";
} }
bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; } bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
bool IsFusedBatchNormGradV1(const NodeDef& node) { bool IsFusedBatchNormGrad(const NodeDef& node) {
return node.op() == "FusedBatchNormGrad"; const auto& op = node.op();
return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2";
} }
bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
bool IsIdentity(const NodeDef& node) { bool IsIdentity(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "Identity" || op == "RefIdentity"; return op == "Identity" || op == "RefIdentity";
} }
bool IsIdentityN(const NodeDef& node) {
const auto& op = node.op();
return op == "IdentityN";
}
bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; } bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
bool IsMatMul(const NodeDef& node) { bool IsMatMul(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" || return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" ||
op == "SparseMatMul"; op == "SparseMatMul";
} }
bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
bool IsMerge(const NodeDef& node) { bool IsMerge(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "Merge" || op == "RefMerge"; return op == "Merge" || op == "RefMerge";
} }
bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
bool IsMul(const NodeDef& node) { return node.op() == "Mul"; } bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; } bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
bool IsNextIteration(const NodeDef& node) { bool IsNextIteration(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "NextIteration" || op == "RefNextIteration"; return op == "NextIteration" || op == "RefNextIteration";
@ -138,6 +192,12 @@ bool IsPlaceholder(const NodeDef& node) {
op == "PlaceholderWithDefault"; op == "PlaceholderWithDefault";
} }
bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; } bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
bool IsReciprocalGrad(const NodeDef& node) { bool IsReciprocalGrad(const NodeDef& node) {
@ -209,12 +269,18 @@ bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; } bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
bool IsVariable(const NodeDef& node) { bool IsVariable(const NodeDef& node) {
const auto& op = node.op(); const auto& op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
op == "VarHandleOp" || op == "ReadVariableOp"; op == "VarHandleOp" || op == "ReadVariableOp";
} }
bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
namespace { namespace {
bool GetBoolAttr(const NodeDef& node, const string& name) { bool GetBoolAttr(const NodeDef& node, const string& name) {
return node.attr().count(name) > 0 && node.attr().at(name).b(); return node.attr().count(name) > 0 && node.attr().at(name).b();
@ -284,5 +350,10 @@ bool IsValuePreserving(const NodeDef& node) {
return value_preserving_ops.count(node.op()) > 0; return value_preserving_ops.count(node.op()) > 0;
} }
bool HasOpDef(const NodeDef& node) {
const OpDef* op_def = nullptr;
return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
}
} // namespace grappler } // namespace grappler
} // end namespace tensorflow } // end namespace tensorflow
View File
@ -24,11 +24,18 @@ namespace grappler {
bool IsAdd(const NodeDef& node); bool IsAdd(const NodeDef& node);
bool IsAddN(const NodeDef& node); bool IsAddN(const NodeDef& node);
bool IsAngle(const NodeDef& node);
bool IsAnyDiv(const NodeDef& node); bool IsAnyDiv(const NodeDef& node);
bool IsApproximateEqual(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node); bool IsAvgPoolGrad(const NodeDef& node);
bool IsAssert(const NodeDef& node); bool IsAssert(const NodeDef& node);
bool IsAtan2(const NodeDef& node);
bool IsBiasAdd(const NodeDef& node); bool IsBiasAdd(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node); bool IsBiasAddGrad(const NodeDef& node);
bool IsBitcast(const NodeDef& node);
bool IsComplex(const NodeDef& node);
bool IsComplexAbs(const NodeDef& node);
bool IsConj(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node); bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node); bool IsConstant(const NodeDef& node);
bool IsConv2D(const NodeDef& node); bool IsConv2D(const NodeDef& node);
@ -41,18 +48,38 @@ bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node); bool IsDiv(const NodeDef& node);
bool IsEluGrad(const NodeDef& node); bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node); bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
bool IsExit(const NodeDef& node); bool IsExit(const NodeDef& node);
bool IsFloorDiv(const NodeDef& node);
bool IsFloorMod(const NodeDef& node); bool IsFloorMod(const NodeDef& node);
bool IsFusedBatchNormGradV1(const NodeDef& node); bool IsFusedBatchNormGrad(const NodeDef& node);
bool IsGreater(const NodeDef& node);
bool IsGreaterEqual(const NodeDef& node);
bool IsIdentity(const NodeDef& node); bool IsIdentity(const NodeDef& node);
bool IsIdentityN(const NodeDef& node);
bool IsIgamma(const NodeDef& node);
bool IsIgammac(const NodeDef& node);
bool IsImag(const NodeDef& node);
bool IsInvGrad(const NodeDef& node); bool IsInvGrad(const NodeDef& node);
bool IsLess(const NodeDef& node);
bool IsLessEqual(const NodeDef& node);
bool IsLogicalAnd(const NodeDef& node);
bool IsLogicalNot(const NodeDef& node);
bool IsLogicalOr(const NodeDef& node);
bool IsMaximum(const NodeDef& node);
bool IsMerge(const NodeDef& node); bool IsMerge(const NodeDef& node);
bool IsMinimum(const NodeDef& node);
bool IsMod(const NodeDef& node);
bool IsMul(const NodeDef& node); bool IsMul(const NodeDef& node);
bool IsMatMul(const NodeDef& node); bool IsMatMul(const NodeDef& node);
bool IsNextIteration(const NodeDef& node); bool IsNextIteration(const NodeDef& node);
bool IsPad(const NodeDef& node); bool IsPad(const NodeDef& node);
bool IsNoOp(const NodeDef& node); bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node); bool IsPlaceholder(const NodeDef& node);
bool IsPolygamma(const NodeDef& node);
bool IsPow(const NodeDef& node);
bool IsReal(const NodeDef& node);
bool IsRealDiv(const NodeDef& node); bool IsRealDiv(const NodeDef& node);
bool IsRelu6Grad(const NodeDef& node); bool IsRelu6Grad(const NodeDef& node);
bool IsReluGrad(const NodeDef& node); bool IsReluGrad(const NodeDef& node);
@ -80,7 +107,10 @@ bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node); bool IsSwitch(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node); bool IsTanhGrad(const NodeDef& node);
bool IsTranspose(const NodeDef& node); bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
bool IsTruncateMod(const NodeDef& node);
bool IsVariable(const NodeDef& node); bool IsVariable(const NodeDef& node);
bool IsZeta(const NodeDef& node);
// Return true if the op is an aggregation (e.g. Add, AddN). // Return true if the op is an aggregation (e.g. Add, AddN).
// Returns false if it could not be determined to be so. // Returns false if it could not be determined to be so.
@ -102,6 +132,9 @@ bool IsInvolution(const NodeDef& node);
// function returns true if the op commutes with all element-wise operations. // function returns true if the op commutes with all element-wise operations.
bool IsValuePreserving(const NodeDef& node); bool IsValuePreserving(const NodeDef& node);
// Returns true if we can find an OpDef corresponding to the op of the node.
bool HasOpDef(const NodeDef& node);
} // end namespace grappler } // end namespace grappler
} // end namespace tensorflow } // end namespace tensorflow
View File
@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
@ -350,15 +351,16 @@ Status DependencyOptimizer::TransitiveReduction() {
num_nodes); num_nodes);
for (int node_idx = 0; node_idx < num_nodes; ++node_idx) { for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
const NodeDef& node = optimized_graph_->node(node_idx); const NodeDef& node = optimized_graph_->node(node_idx);
if (ModifiesFrameInfo(node)) { if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
// Ignore nodes that modify frame info. // Ignore function nodes and nodes that modify frame info.
continue; continue;
} }
for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) { for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
const string& input = node.input(input_slot); const string& input = node.input(input_slot);
const NodeDef* input_node = node_map_->GetNode(input); const NodeDef* input_node = node_map_->GetNode(input);
if (ModifiesFrameInfo(*input_node)) { if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
// Ignore edges from nodes that modify frame info. // Ignore edges from nodes that modify frame info and from Merge nodes,
// because we cannot know which of its input paths executes.
continue; continue;
} }
const int input_node_idx = node_to_idx_[input_node]; const int input_node_idx = node_to_idx_[input_node];
@ -375,6 +377,14 @@ Status DependencyOptimizer::TransitiveReduction() {
// of length > 1, we can drop that control dependency. // of length > 1, we can drop that control dependency.
int num_controls_removed = 0; int num_controls_removed = 0;
std::vector<int> longest_distance(num_nodes); std::vector<int> longest_distance(num_nodes);
// Map from target_index -> set of (input_slot, source_index), representing
// the control edges to remove. We sort them in reverse order by input slot,
// so that when we swap them out we don't clobber the
// node(target).input() repeated field.
typedef std::pair<int, int> InputSlotAndSource;
std::unordered_map<
int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
control_edges_to_remove;
for (int source = 0; source < num_nodes; ++source) { for (int source = 0; source < num_nodes; ++source) {
int highest_control_target = -1; int highest_control_target = -1;
for (const auto& control_output : control_outputs[source]) { for (const auto& control_output : control_outputs[source]) {
@ -382,7 +392,7 @@ Status DependencyOptimizer::TransitiveReduction() {
highest_control_target = control_output.first; highest_control_target = control_output.first;
} }
} }
if (highest_control_target < source) { if (highest_control_target <= source) {
continue; continue;
} }
std::fill(longest_distance.begin() + source, std::fill(longest_distance.begin() + source,
@ -391,7 +401,10 @@ Status DependencyOptimizer::TransitiveReduction() {
for (int input : inputs[target]) { for (int input : inputs[target]) {
// If the input node is before source in the topo order, no path // If the input node is before source in the topo order, no path
// source -> input -> target can exist and we can skip it. // source -> input -> target can exist and we can skip it.
if (input >= source) { // Also only extend a path from the source itself or from nodes that
// have a path from source, indicated by longest_distance[input] > 0.
if (input == source ||
(input > source && longest_distance[input] > 0)) {
// If source -> input -> target is longer than the longest // If source -> input -> target is longer than the longest
// path so far from source -> target, update the longest_distance. // path so far from source -> target, update the longest_distance.
int candidate_longest_distance = longest_distance[input] + 1; int candidate_longest_distance = longest_distance[input] + 1;
@ -402,25 +415,36 @@ Status DependencyOptimizer::TransitiveReduction() {
} }
} }
// If the longest path from the source to the target of a control dependency // If the longest path from source to target of a control dependency is
// is longer than 1, there exists an alternate path, and we can eliminate // longer than 1, there exists an alternate path, and we can eliminate the
// the control dependency since it is redundant. // redundant direct control dependency.
for (const auto& control_output : control_outputs[source]) { for (const auto& control_output : control_outputs[source]) {
const int target = control_output.first; const int target = control_output.first;
if (longest_distance[target] > 1) { if (longest_distance[target] > 1) {
const int input_slot = control_output.second; const int input_slot = control_output.second;
// We modify the node inplace here. This is safe because there can control_edges_to_remove[target].emplace(input_slot, source);
// only be one control edge from a given source to a given target. VLOG(1) << "Removing edge from:\n"
const NodeDef& source_node = optimized_graph_->node(source); << optimized_graph_->node(source).DebugString() << "\n\nto:\n\n"
<< optimized_graph_->node(target).DebugString();
}
}
}
for (const auto& it : control_edges_to_remove) {
const int target = it.first;
NodeDef* target_node = optimized_graph_->mutable_node(target); NodeDef* target_node = optimized_graph_->mutable_node(target);
target_node->mutable_input()->SwapElements( for (const InputSlotAndSource& slot_and_source : it.second) {
input_slot, target_node->input_size() - 1); const int input_slot = slot_and_source.first;
const int source = slot_and_source.second;
const NodeDef& source_node = optimized_graph_->node(source);
CHECK_LT(input_slot, target_node->input_size());
target_node->mutable_input()->SwapElements(input_slot,
target_node->input_size() - 1);
node_map_->RemoveOutput(source_node.name(), target_node->name()); node_map_->RemoveOutput(source_node.name(), target_node->name());
target_node->mutable_input()->RemoveLast(); target_node->mutable_input()->RemoveLast();
++num_controls_removed; ++num_controls_removed;
} }
} }
}
VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
<< " control dependencies"; << " control dependencies";
return Status::OK(); return Status::OK();
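For reference, a minimal plain-Python sketch of the transitive reduction implemented above, assuming nodes 0..num_nodes-1 are already topologically sorted, inputs[t] lists the input node indices of node t, and control_outputs[s] lists (target, input_slot) pairs for s's control edges. A control edge source -> target is redundant exactly when the longest path from source to target has length > 1:

def redundant_control_edges(num_nodes, inputs, control_outputs):
  to_remove = []  # (target, input_slot, source) triples
  for source in range(num_nodes):
    targets = [t for t, _ in control_outputs[source]]
    if not targets or max(targets) <= source:
      continue
    longest = [0] * num_nodes  # longest[i]: longest path length source -> i
    for target in range(source + 1, max(targets) + 1):
      for inp in inputs[target]:
        # Only extend paths that actually start at the source, i.e. from
        # the source itself or from nodes already reached from it.
        if inp == source or (inp > source and longest[inp] > 0):
          longest[target] = max(longest[target], longest[inp] + 1)
    for target, slot in control_outputs[source]:
      if longest[target] > 1:  # an alternate, longer path exists
        to_remove.append((target, slot, source))
  return to_remove

The pass then deletes the collected edges per target in descending input-slot order: each deletion swaps the doomed slot with the last element of the repeated input field before removing it, so ascending order could move a not-yet-deleted slot out from under its stored index.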
@ -442,36 +466,27 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
nodes_to_preserve_ = item.NodesToPreserve(); nodes_to_preserve_ = item.NodesToPreserve();
fetch_nodes_known_ = !item.fetch.empty(); fetch_nodes_known_ = !item.fetch.empty();
VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
CleanControlInputs(); CleanControlInputs();
const int num_iterations = opt_level_ == RewriterConfig::AGGRESSIVE ? 2 : 1; const int num_iterations = 2;
for (int iteration = 0; iteration < num_iterations; ++iteration) { for (int iteration = 0; iteration < num_iterations; ++iteration) {
Status topo_sort_status; Status topo_sort_status;
if (opt_level_ == RewriterConfig::AGGRESSIVE) { // Perform topological sort to prepare the graph for transitive reduction.
// Prepare the graph for transitive reduction if enabled.
topo_sort_status = TopologicalSort(optimized_graph_); topo_sort_status = TopologicalSort(optimized_graph_);
}
// Set up index-based graph datastructures to speed up analysis steps below.
node_map_.reset(new NodeMap(optimized_graph_)); node_map_.reset(new NodeMap(optimized_graph_));
BuildNodeToIdx(); BuildNodeToIdx();
// Remove redundant control dependencies, iteration 1.
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
if (topo_sort_status.ok()) { if (topo_sort_status.ok()) {
// Remove redundant control dependencies.
TF_RETURN_IF_ERROR(TransitiveReduction()); TF_RETURN_IF_ERROR(TransitiveReduction());
} else { } else {
LOG(ERROR) << topo_sort_status.error_message(); LOG(ERROR) << topo_sort_status.error_message();
} }
VLOG(1) << "Graph after transitive reduction:\n"
<< optimized_graph_->DebugString();
}
// Turn nodes without non-control outputs into NoOps, prune NoOps. // Turn nodes with only control outputs into NoOps, prune NoOps.
TF_RETURN_IF_ERROR(OptimizeDependencies()); TF_RETURN_IF_ERROR(OptimizeDependencies());
VLOG(1) << "Graph after NoOp conversion & pruning:\n"
<< optimized_graph_->DebugString();
} }
VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString();
return Status::OK(); return Status::OK();
} }
View File
@ -157,7 +157,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
GrapplerItem item; GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph)); TF_CHECK_OK(s.ToGraphDef(&item.graph));
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE); DependencyOptimizer optimizer;
GraphDef output; GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output); Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status); TF_EXPECT_OK(status);
@ -228,6 +228,7 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) {
// The optimization should be disabled to prevent increasing the number of // The optimization should be disabled to prevent increasing the number of
// nodes crossing device boundaries. // nodes crossing device boundaries.
TF_CHECK_OK(TopologicalSort(&item.graph));
VerifyGraphsEqual(item.graph, output, __FUNCTION__); VerifyGraphsEqual(item.graph, output, __FUNCTION__);
} }
@ -282,7 +283,7 @@ TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
GrapplerItem item; GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph)); TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch.push_back("id2"); item.fetch.push_back("id2");
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE); DependencyOptimizer optimizer;
GraphDef output; GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output); Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status); TF_EXPECT_OK(status);
View File
@ -60,7 +60,9 @@ std::set<string> GetOpsFormatSupported() {
"DepthwiseConv2dNativeBackpropInput", "DepthwiseConv2dNativeBackpropInput",
"DepthwiseConv2dNativeBackpropFilter", "DepthwiseConv2dNativeBackpropFilter",
"FusedBatchNorm", "FusedBatchNorm",
"FusedBatchNormV2",
"FusedBatchNormGrad", "FusedBatchNormGrad",
"FusedBatchNormGradV2",
"FusedConv2DBiasActivation", "FusedConv2DBiasActivation",
"MaxPool", "MaxPool",
"MaxPoolGrad", "MaxPoolGrad",
@ -75,52 +77,77 @@ std::set<string> GetOpsFormatAgnostic() {
std::set<string> ops_format_agnostic = {"Abs", std::set<string> ops_format_agnostic = {"Abs",
"Add", "Add",
"AddN", "AddN",
"AddV2",
"Acos", "Acos",
"Acosh", "Acosh",
"Angle", "Angle",
"ApproximateEqual",
"Asin", "Asin",
"Asinh", "Asinh",
"Atan", "Atan",
"Atan2",
"Atanh", "Atanh",
"Bitcast", "Bitcast",
"Cast", "Cast",
"Ceil", "Ceil",
"CheckNumerics", "CheckNumerics",
"Cos", "Complex",
"Cosh",
"ComplexAbs", "ComplexAbs",
"Concat", "Concat",
"ConcatV2", "ConcatV2",
"Conj", "Conj",
"Cos",
"Cosh",
"Digamma", "Digamma",
"Div",
"Elu", "Elu",
"EluGrad", "EluGrad",
"Equal",
"Erf", "Erf",
"Erfc", "Erfc",
"Exp", "Exp",
"Expm1", "Expm1",
"Floor", "Floor",
"FloorDiv",
"FloorMod",
"Greater",
"GreaterEqual",
"GuaranteeConst", "GuaranteeConst",
"Identity", "Identity",
"IdentityN",
"Igamma",
"Igammac",
"Imag", "Imag",
"Inv", "Inv",
"InvGrad", "InvGrad",
"IsFinite", "IsFinite",
"IsInf", "IsInf",
"IsNan", "IsNan",
"Less",
"LessEqual",
"Lgamma", "Lgamma",
"Log", "Log",
"LogicalAnd",
"LogicalNot",
"LogicalOr",
"Log1p", "Log1p",
"Maximum",
"Merge", "Merge",
"Minimum",
"Mod",
"Mul", "Mul",
"Neg", "Neg",
"NotEqual",
"OnesLike", "OnesLike",
"Pad", "Pad",
"PreventGradient", "PreventGradient",
"Polygamma",
"Pow",
"Real", "Real",
"RealDiv", "RealDiv",
"Reciprocal", "Reciprocal",
"ReciprocalGrad", "ReciprocalGrad",
"RefIdentity",
"Relu", "Relu",
"Relu6", "Relu6",
"Relu6Grad", "Relu6Grad",
@ -141,7 +168,8 @@ std::set<string> GetOpsFormatAgnostic() {
"SoftplusGrad", "SoftplusGrad",
"Split", "Split",
"Switch", "Switch",
"RefIdentity", "TruncateDiv",
"TruncateMod",
"RefMerge", "RefMerge",
"RefSwitch", "RefSwitch",
"Round", "Round",
@ -157,7 +185,8 @@ std::set<string> GetOpsFormatAgnostic() {
"Tan", "Tan",
"Tanh", "Tanh",
"TanhGrad", "TanhGrad",
"ZerosLike"}; "ZerosLike",
"Zeta"};
return ops_format_agnostic; return ops_format_agnostic;
} }
@ -212,6 +241,28 @@ bool IsUnaryGrad(const NodeDef& node) {
return is_unary_grad; return is_unary_grad;
} }
bool IsComparisonOp(const NodeDef& node) {
bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
IsLessEqual(node) || IsNotEqual(node);
return is_compare;
}
bool IsLogicalOp(const NodeDef& node) {
return IsLogicalAnd(node) || IsLogicalNot(node) || IsLogicalOr(node);
}
bool IsBinaryOp(const NodeDef& node) {
bool is_binary =
IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
return is_binary;
}
class GraphProcessor { class GraphProcessor {
public: public:
GraphProcessor(const VirtualPlacer& virtual_placer, GraphProcessor(const VirtualPlacer& virtual_placer,
@ -409,10 +460,11 @@ class NodeProcessor : public GraphProcessor {
virtual void UpdateAttrShape() { virtual void UpdateAttrShape() {
if (node_->attr().find("_output_shapes") != node_->attr().end()) { if (node_->attr().find("_output_shapes") != node_->attr().end()) {
for (const auto& pos : GetOutputPos()) {
auto shape = node_->mutable_attr() auto shape = node_->mutable_attr()
->at("_output_shapes") ->at("_output_shapes")
.mutable_list() .mutable_list()
->mutable_shape(0); ->mutable_shape(pos);
if (shape->dim_size() == 4) { if (shape->dim_size() == 4) {
int64 h = shape->dim(1).size(); int64 h = shape->dim(1).size();
int64 w = shape->dim(2).size(); int64 w = shape->dim(2).size();
@ -423,6 +475,7 @@ class NodeProcessor : public GraphProcessor {
} }
} }
} }
}
void UpdateAttrKSize() { void UpdateAttrKSize() {
if (node_->attr().find("ksize") != node_->attr().end()) { if (node_->attr().find("ksize") != node_->attr().end()) {
@ -603,13 +656,15 @@ class NodeProcessor : public GraphProcessor {
added_node_name = AddPrefixToNodeName(added_node_base_name, added_node_name = AddPrefixToNodeName(added_node_base_name,
kTransposeNCHWToNHWC, "-"); kTransposeNCHWToNHWC, "-");
DataType dtype; DataType dtype;
if (op == "Imag" || op == "Real" || op == "Angle" || if (IsAngle(*node_) || IsComplex(*node_) ||
op == "Conj" || op == "ComplexAbs") { IsComplexAbs(*node_) || IsImag(*node_) || IsReal(*node_)) {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "Tout")); TF_RETURN_IF_ERROR(HasAttribute(*node_, "Tout"));
dtype = node_->attr().at("Tout").type(); dtype = node_->attr().at("Tout").type();
} else if (op == "Bitcast") { } else if (IsBitcast(*node_)) {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "type")); TF_RETURN_IF_ERROR(HasAttribute(*node_, "type"));
dtype = node_->attr().at("type").type(); dtype = node_->attr().at("type").type();
} else if (IsLogicalOp(*node_) || IsComparisonOp(*node_)) {
dtype = DT_BOOL;
} else { } else {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T")); TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
dtype = node_->attr().at("T").type(); dtype = node_->attr().at("T").type();
@ -617,7 +672,8 @@ class NodeProcessor : public GraphProcessor {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes")); TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
AddNodeTranspose( AddNodeTranspose(
added_node_name, input, const_name, dtype, added_node_name, input, const_name, dtype,
node_->attr().at("_output_shapes").list().shape(0), false); node_->attr().at("_output_shapes").list().shape(input_port),
false);
} else if (op == "DataFormatVecPermute") { } else if (op == "DataFormatVecPermute") {
added_node_name = AddPrefixToNodeName(added_node_base_name, added_node_name = AddPrefixToNodeName(added_node_base_name,
kVecPermuteNCHWToNHWC, "-"); kVecPermuteNCHWToNHWC, "-");
@ -1002,11 +1058,10 @@ class AgnosticNodeProcessor : public NodeProcessor {
if (IsConcatV1(node)) { if (IsConcatV1(node)) {
return {1}; return {1};
} }
if (IsAdd(node) || IsMul(node) || IsRealDiv(node) || if (IsBinaryOp(node) || IsUnaryGrad(node)) {
IsSquaredDifference(node) || IsSub(node)) {
return {0, 1}; return {0, 1};
} }
if (IsShapeN(node)) { if (IsShapeN(node) || IsIdentityN(node)) {
std::vector<int> pos; std::vector<int> pos;
for (int i = 0; i < node.input_size(); i++) { for (int i = 0; i < node.input_size(); i++) {
pos.push_back(i); pos.push_back(i);
@ -1207,6 +1262,40 @@ class ConcatProcessor : public AgnosticNodeProcessor {
int axis_node_pos_; int axis_node_pos_;
}; };
class IdentityNProcessor : public AgnosticNodeProcessor {
public:
explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
: AgnosticNodeProcessor(opt_cxt) {}
protected:
bool ShouldProcess() const override {
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsOnGPU();
}
std::vector<int> GetInputPos() const override {
std::vector<int> input_pos;
for (int i = 0; i < node_->input_size(); i++) {
auto input = node_map_->GetNode(node_->input(i));
int port;
ParseNodeName(node_->input(i), &port);
if (IsPortDimsFour(*input, port) &&
(IsNodeAfterNCHWToNHWC(*input) || IsNodeNCHWToNHWC(input->name()))) {
input_pos.push_back(i);
}
}
return input_pos;
}
std::set<int> GetOutputPos() const override {
std::set<int> output_pos{};
for (const auto& input_pos : GetInputPos()) {
output_pos.insert(input_pos);
}
return output_pos;
}
};
class MergeProcessor : public AgnosticNodeProcessor { class MergeProcessor : public AgnosticNodeProcessor {
public: public:
explicit MergeProcessor(const OptimizeContext& opt_cxt) explicit MergeProcessor(const OptimizeContext& opt_cxt)
@ -1371,12 +1460,15 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
bool IsInputConvertible() const { bool IsInputConvertible() const {
int input_port;
auto input = node_map_->GetNode(node_->input(0)); auto input = node_map_->GetNode(node_->input(0));
ParseNodeName(node_->input(0), &input_port);
if (IsNodeNCHWToNHWC(input->name())) { if (IsNodeNCHWToNHWC(input->name())) {
input = node_map_->GetNode(input->input(0)); input = node_map_->GetNode(input->input(0));
ParseNodeName(input->input(0), &input_port);
} }
if (input->attr().find("_output_shapes") != input->attr().end()) { if (input->attr().find("_output_shapes") != input->attr().end()) {
auto shape = input->attr().at("_output_shapes").list().shape(0); auto shape = input->attr().at("_output_shapes").list().shape(input_port);
if (shape.dim_size() != 4) { if (shape.dim_size() != 4) {
return false; return false;
} }
@ -1529,7 +1621,7 @@ class DataLayoutOptimizer : GraphProcessor {
new Conv2DBackpropFilterProcessor(opt_cxt, true)); new Conv2DBackpropFilterProcessor(opt_cxt, true));
} else if (IsDepthwiseConv2dNativeBackpropInput(*node)) { } else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true)); node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
} else if (IsFusedBatchNormGradV1(*node)) { } else if (IsFusedBatchNormGrad(*node)) {
node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt)); node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
} else if (IsMaxPoolGradV1(*node)) { } else if (IsMaxPoolGradV1(*node)) {
node_processor.reset(new MaxPoolGradProcessor(opt_cxt)); node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
@ -1557,11 +1649,12 @@ class DataLayoutOptimizer : GraphProcessor {
std::unique_ptr<NodeProcessor> node_processor; std::unique_ptr<NodeProcessor> node_processor;
if (IsAddN(*node)) { if (IsAddN(*node)) {
node_processor.reset(new AddNProcessor(opt_cxt)); node_processor.reset(new AddNProcessor(opt_cxt));
} else if (IsAdd(*node) || IsMul(*node) || IsRealDiv(*node) || } else if (IsBinaryOp(*node)) {
IsSquaredDifference(*node) || IsSub(*node)) {
node_processor.reset(new BinaryOpProcessor(opt_cxt)); node_processor.reset(new BinaryOpProcessor(opt_cxt));
} else if (IsConcat(*node)) { } else if (IsConcat(*node)) {
node_processor.reset(new ConcatProcessor(opt_cxt)); node_processor.reset(new ConcatProcessor(opt_cxt));
} else if (IsIdentityN(*node)) {
node_processor.reset(new IdentityNProcessor(opt_cxt));
} else if (IsMerge(*node)) { } else if (IsMerge(*node)) {
node_processor.reset(new MergeProcessor(opt_cxt)); node_processor.reset(new MergeProcessor(opt_cxt));
} else if (IsPad(*node)) { } else if (IsPad(*node)) {
View File
@ -1066,6 +1066,48 @@ TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
"LayoutOptimizerTransposeNCHWToNHWC-Conv2D-0-1"); "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-0-1");
} }
TEST_F(LayoutOptimizerTest, Complex) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 4, 2, "VALID");
auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
auto i = ops::Identity(s.WithOpName("i"), comp);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto merge_node = node_map.GetNode("complex");
EXPECT_EQ(merge_node->input(0), "Conv2D");
EXPECT_EQ(merge_node->input(1), "Conv2D");
auto trans =
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-complex-0-0");
EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
}
TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 4, 2, "VALID");
auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto i = node_map.GetNode("identity_n");
EXPECT_EQ(i->input(0), "vector");
EXPECT_EQ(i->input(1), "Conv2D");
auto trans =
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
EXPECT_EQ(trans->input(0), "identity_n:1");
auto add_node = node_map.GetNode("add");
EXPECT_EQ(add_node->input(0), "identity_n");
EXPECT_EQ(add_node->input(1),
"LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
}
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow
View File
@ -314,6 +314,11 @@ Status CudaSolver::forward_input_or_allocate_scoped_tensor(
// are sometimes inaccurate, e.g., are missing 'const' on pointers // are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected. // to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file. // Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be thread-safe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//============================================================================= //=============================================================================
template <typename Scalar, typename SolverFnT> template <typename Scalar, typename SolverFnT>
@ -324,6 +329,7 @@ static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
const Scalar* A, int lda, const Scalar* A, int lda,
const Scalar* beta, /* host or device pointer */ const Scalar* beta, /* host or device pointer */
const Scalar* B, int ldb, Scalar* C, int ldc) { const Scalar* B, int ldb, Scalar* C, int ldc) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type; using CudaScalar = typename CUDAComplexT<Scalar>::type;
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n, TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
reinterpret_cast<const CudaScalar*>(alpha), reinterpret_cast<const CudaScalar*>(alpha),
@ -355,6 +361,7 @@ static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, cusolverDnHandle_t cusolver_dn_handle,
cublasFillMode_t uplo, int n, Scalar* A, int lda, cublasFillMode_t uplo, int n, Scalar* A, int lda,
int* dev_lapack_info) { int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */ /* Get amount of workspace memory required. */
int lwork; int lwork;
TF_RETURN_IF_CUSOLVER_ERROR( TF_RETURN_IF_CUSOLVER_ERROR(
@ -387,6 +394,7 @@ static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m, cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, int* dev_pivots, int n, Scalar* A, int lda, int* dev_pivots,
int* dev_lapack_info) { int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */ /* Get amount of workspace memory required. */
int lwork; int lwork;
TF_RETURN_IF_CUSOLVER_ERROR( TF_RETURN_IF_CUSOLVER_ERROR(
@ -419,9 +427,6 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
cublasOperation_t trans, int n, int nrhs, cublasOperation_t trans, int n, int nrhs,
const Scalar* A, int lda, const int* pivots, const Scalar* A, int lda, const int* pivots,
Scalar* B, int ldb, int* dev_lapack_info) { Scalar* B, int ldb, int* dev_lapack_info) {
// Note: The cuSolver functions called here appear not to be threadsafe.
// so we put a global lock around it. Since this function only puts a
// kernel on the stream, it is not a big performance hit.
mutex_lock lock(handle_map_mutex); mutex_lock lock(handle_map_mutex);
/* Launch the solver kernel. */ /* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs, TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
@ -449,6 +454,7 @@ static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m, cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, Scalar* tau, int n, Scalar* A, int lda, Scalar* tau,
int* dev_lapack_info) { int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */ /* Get amount of workspace memory required. */
int lwork; int lwork;
TF_RETURN_IF_CUSOLVER_ERROR( TF_RETURN_IF_CUSOLVER_ERROR(
@ -483,6 +489,7 @@ static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
int m, int n, int k, const Scalar* dev_a, int m, int n, int k, const Scalar* dev_a,
int lda, const Scalar* dev_tau, Scalar* dev_c, int lda, const Scalar* dev_tau, Scalar* dev_c,
int ldc, int* dev_lapack_info) { int ldc, int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */ /* Get amount of workspace memory required. */
int lwork; int lwork;
TF_RETURN_IF_CUSOLVER_ERROR( TF_RETURN_IF_CUSOLVER_ERROR(
@ -526,6 +533,7 @@ static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m, cusolverDnHandle_t cusolver_dn_handle, int m,
int n, int k, Scalar* dev_a, int lda, int n, int k, Scalar* dev_a, int lda,
const Scalar* dev_tau, int* dev_lapack_info) { const Scalar* dev_tau, int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */ /* Get amount of workspace memory required. */
int lwork; int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k, TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
@ -606,17 +614,13 @@ static inline Status GesvdImpl(
OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle, OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda, signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) { Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */ /* Get amount of workspace memory required. */
int lwork; int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork)); TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
/* Allocate device memory for workspace. */ /* Allocate device memory for workspace. */
auto dev_workspace = auto dev_workspace =
cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false); cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
// Note: The cuSolver functions called here appear not to be threadsafe.
// so we put a global lock around it. Since this function only puts a
// kernel on the stream, it is not a big performance hit.
mutex_lock lock(handle_map_mutex);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n, TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
CUDAComplex(A), lda, S, CUDAComplex(U), CUDAComplex(A), lda, S, CUDAComplex(U),
ldu, CUDAComplex(VT), ldvt, ldu, CUDAComplex(VT), ldvt,
@ -655,6 +659,7 @@ static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
int lda, int* dev_pivots, int lda, int* dev_pivots,
DeviceLapackInfo* dev_lapack_info, DeviceLapackInfo* dev_lapack_info,
int batch_size) { int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type; using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs = ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@ -689,6 +694,7 @@ static inline Status GetrsBatchedImpl(
const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
const Scalar* const host_b_dev_ptrs[], int ldb, const Scalar* const host_b_dev_ptrs[], int ldb,
DeviceLapackInfo* dev_lapack_info, int batch_size) { DeviceLapackInfo* dev_lapack_info, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type; using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs = ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "", cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@ -734,6 +740,7 @@ static inline Status GetriBatchedImpl(
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@@ -776,6 +783,7 @@ static inline Status MatInvBatchedImpl(
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
DeviceLapackInfo* dev_lapack_info, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
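The change repeated across these cuSolver/cuBLAS wrappers hoists a scoped lock on the shared handle_map_mutex to the top of each *Impl function: the handles are not threadsafe, and because each call only enqueues work on a stream, holding the lock for the whole body is cheap. A minimal standalone sketch of that scoped-lock idiom, assuming std::mutex in place of TensorFlow's mutex and a hypothetical solver_call stand-in:

#include <mutex>

// Global mutex serializing access to a shared, non-threadsafe solver handle,
// in the spirit of the diff's handle_map_mutex.
static std::mutex handle_map_mutex;

// Hypothetical stand-in for a cuSolver-style call that only enqueues a kernel.
int solver_call(int m, int n) { return m * n; }

int SolverImpl(int m, int n) {
  // The lock now covers workspace sizing, scratch allocation, and the launch.
  // Since only an enqueue happens here, serialization is not a big cost.
  std::lock_guard<std::mutex> lock(handle_map_mutex);
  return solver_call(m, n);
}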

View File

@@ -493,6 +493,18 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "unique_dataset_op",
srcs = ["unique_dataset_op.cc"],
deps = [
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "dataset_ops",
deps = [
@@ -526,6 +538,7 @@ tf_kernel_library(
":take_dataset_op",
":tensor_dataset_op",
":tensor_slice_dataset_op",
":unique_dataset_op",
":zip_dataset_op",
],
)

View File

@@ -57,7 +57,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
#define HANDLE_TYPE(T) \
case DataTypeToEnum<T>::value: { \
*output = new Dataset<T>(batch_size, row_shape, input); \
*output = new Dataset<T>(ctx, batch_size, row_shape, input); \
break; \
}
@@ -75,11 +75,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
private:
// TODO(mrry): Push the templated code down to the raw copying routine.
template <class T>
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
Dataset(int64 batch_size, const PartialTensorShape& row_shape,
const DatasetBase* input)
: batch_size_(batch_size), row_shape_(row_shape), input_(input) {
Dataset(OpKernelContext* ctx, int64 batch_size,
const PartialTensorShape& row_shape, const DatasetBase* input)
: GraphDatasetBase(ctx),
batch_size_(batch_size),
row_shape_(row_shape),
input_(input) {
input_->Ref();
output_shapes_.reserve(3);
@@ -112,6 +115,25 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
")::Dataset");
}
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
Node* batch_size_node;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
Node* row_shape_node;
std::vector<int64> row_shape;
row_shape.reserve(
row_shape_.dims()); // not an unknown rank PartialTensorShape
for (int i = 0; i < row_shape_.dims(); i++)
row_shape.emplace_back(row_shape_.dim_size(i));
TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_node, batch_size_node, row_shape_node}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset<T>> {
public:
@@ -242,6 +264,20 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(Iterator::SaveParent(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(Iterator::RestoreParent(ctx, reader, input_impl_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
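The SaveInternal/RestoreInternal overrides added above follow a simple contract: take the iterator's lock, delegate to the parent iterator, and round-trip any local state through named keys. A self-contained sketch of that round trip, with a flat string-keyed map standing in for the IteratorStateWriter/Reader interfaces (a simplification; the real interfaces read and write Tensors):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for IteratorStateWriter/Reader: a flat key/value map.
using State = std::map<std::string, int64_t>;

struct Iterator {
  int64_t next_index = 0;

  // Mirrors SaveInternal: persist only what is needed to resume.
  void Save(State* writer) const { (*writer)["next_index"] = next_index; }

  // Mirrors RestoreInternal: read the same keys back.
  void Restore(const State& reader) { next_index = reader.at("next_index"); }
};

int main() {
  Iterator it;
  it.next_index = 42;
  State checkpoint;
  it.Save(&checkpoint);

  Iterator restored;
  restored.Restore(checkpoint);
  std::cout << restored.next_index << "\n";  // prints 42
}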

View File

@@ -64,20 +64,23 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
std::move(other_arguments),
&captured_func));
*output =
new Dataset(input, std::move(initial_state), std::move(captured_func),
state_types_, output_types_, output_shapes_);
*output = new Dataset(ctx, input, func_, std::move(initial_state),
std::move(captured_func), state_types_, output_types_,
output_shapes_);
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
Dataset(const DatasetBase* input, std::vector<Tensor> initial_state,
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func, std::vector<Tensor> initial_state,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& state_types,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: input_(input),
: GraphDatasetBase(ctx),
input_(input),
func_(func),
initial_state_(std::move(initial_state)),
captured_func_(std::move(captured_func)),
state_types_(state_types),
@@ -103,6 +106,45 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "ScanDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
std::vector<Node*> initial_state_nodes;
initial_state_nodes.reserve(initial_state_.size());
for (const Tensor& t : initial_state_) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
initial_state_nodes.emplace_back(node);
}
std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
b->BuildAttrValue(func_, &f);
AttrValue state_types;
b->BuildAttrValue(state_types_, &state_types);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
TF_RETURN_IF_ERROR(
b->AddDataset(this, {{0, input_node}},
{{1, initial_state_nodes}, {2, other_arguments}},
{{"f", f},
{"Tstate", state_types},
{"Targuments", other_arguments_types_attr}},
output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -185,6 +227,38 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
return s;
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
if (!state_.empty()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("state_size"), state_.size()));
for (int idx = 0; idx < state_.size(); idx++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("state[", idx, "]")), state_[idx]));
}
}
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
if (reader->Contains(full_name("state_size"))) {
int64 size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("state_size"), &size));
state_.resize(size);
for (int idx = 0; idx < size; idx++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("state[", idx, "]")), &state_[idx]));
}
}
return Status::OK();
}
private:
mutex mu_;
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
@@ -192,6 +266,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
const NameAttrList func_;
const std::vector<Tensor> initial_state_;
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector state_types_;
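ScanDatasetOp's iterator additionally checkpoints its running state as a length-prefixed list: a state_size scalar followed by one state[i] entry per element, with restore reading the size first and then each entry. A hedged sketch of that encoding, using strings in place of Tensors:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Sketch of the length-prefixed encoding: write "state_size", then one
// "state[i]" entry per element, matching the diff's key names.
using Checkpoint = std::map<std::string, std::string>;

void SaveState(const std::vector<std::string>& state, Checkpoint* w) {
  if (!state.empty()) {
    (*w)["state_size"] = std::to_string(state.size());
    for (size_t i = 0; i < state.size(); ++i)
      (*w)["state[" + std::to_string(i) + "]"] = state[i];
  }
}

std::vector<std::string> RestoreState(const Checkpoint& r) {
  std::vector<std::string> state;
  auto it = r.find("state_size");
  if (it != r.end()) {
    state.resize(std::stoul(it->second));
    for (size_t i = 0; i < state.size(); ++i)
      state[i] = r.at("state[" + std::to_string(i) + "]");
  }
  return state;
}

int main() {
  Checkpoint cp;
  SaveState({"a", "b"}, &cp);
  for (const auto& s : RestoreState(cp)) std::cout << s << "\n";
}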

View File

@@ -0,0 +1,219 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
class UniqueDatasetOp : public UnaryDatasetOpKernel {
public:
explicit UniqueDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
OP_REQUIRES(ctx, input->output_dtypes().size() == 1,
errors::InvalidArgument("UniqueDataset only supports "
"inputs with a single component."));
DataType input_dtype = input->output_dtypes()[0];
OP_REQUIRES(ctx,
input_dtype == DT_INT32 || input_dtype == DT_INT64 ||
input_dtype == DT_STRING,
errors::InvalidArgument(
"UniqueDataset only supports inputs with a single "
"`tf.int32`, `tf.int64`, or `tf.string` component."));
*output = new Dataset(ctx, input);
}
private:
class Dataset : public GraphDatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input)
: GraphDatasetBase(ctx), input_(input) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIterator(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Unique")}));
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return input_->output_shapes();
}
string DebugString() override {
return strings::StrCat("UniqueDatasetOp::Dataset");
}
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const typename Iterator::Params& params)
: DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
bool saw_new_value;
do {
saw_new_value = false;
out_tensors->clear();
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (*end_of_sequence) {
break;
}
DCHECK_EQ(1, out_tensors->size());
saw_new_value = unique_elements_.insert((*out_tensors)[0]).second;
} while (!saw_new_value);
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (input_impl_) {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impl_empty"), ""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("unique_elements_size"), unique_elements_.size()));
size_t i = 0;
for (const Tensor& t : unique_elements_) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("unique_elements[", i++, "]")), t));
}
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
int64 num_unique_elements;
unique_elements_.clear();
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
&num_unique_elements));
for (int64 i = 0; i < num_unique_elements; ++i) {
Tensor unique_element;
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("unique_elements[", i, "]")),
&unique_element));
auto insert_result = unique_elements_.insert(unique_element);
if (!insert_result.second) {
return errors::InvalidArgument(
"Checkpoint contained two unique elements with the same "
"value.");
}
}
return Status::OK();
}
private:
struct TensorHash {
size_t operator()(const Tensor& t) const {
if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
return Hash64(t.tensor_data().data(), t.tensor_data().size());
} else {
DCHECK_EQ(DT_STRING, t.dtype());
auto flat_t = t.flat<string>();
uint64 hash = 0;
for (int64 i = 0; i < t.NumElements(); ++i) {
hash = Hash64Combine(hash, Hash64(flat_t(i)));
}
return static_cast<size_t>(hash);
}
}
};
struct TensorKeyEqual {
bool operator()(const Tensor& lhs, const Tensor& rhs) const {
if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
return false;
}
switch (lhs.dtype()) {
#define HANDLE_TYPE(T) \
case T: \
do { \
auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
for (int64 i = 0; i < lhs.NumElements(); ++i) { \
if (lhs_flat(i) != rhs_flat(i)) { \
return false; \
} \
} \
return true; \
} while (0)
HANDLE_TYPE(DT_INT32);
HANDLE_TYPE(DT_INT64);
HANDLE_TYPE(DT_STRING);
default:
LOG(FATAL) << "UniqueDataset unhandled data type: "
<< DataTypeString(lhs.dtype());
}
}
};
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
GUARDED_BY(mu_);
};
const DatasetBase* const input_;
};
};
REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
} // namespace tensorflow
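The dedup loop in GetNextInternal hinges on std::unordered_set with user-supplied hash and equality functors, so insert().second directly reports whether an element is new. A toy, self-contained version of the same pattern; plain strings stand in for the single-component Tensor values, and the multiplicative hash is an arbitrary illustrative choice (the real op combines Hash64/Hash64Combine over tensor contents):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy element standing in for a one-component dataset value.
struct Element {
  std::vector<std::string> values;
};

struct ElementHash {
  size_t operator()(const Element& e) const {
    size_t h = 0;
    std::hash<std::string> hs;
    // Combine per-component hashes, like Hash64Combine over flat strings.
    for (const auto& v : e.values) h = h * 1099511628211ULL + hs(v);
    return h;
  }
};

struct ElementEq {
  bool operator()(const Element& a, const Element& b) const {
    return a.values == b.values;
  }
};

int main() {
  std::unordered_set<Element, ElementHash, ElementEq> seen;
  // Mirrors GetNextInternal: emit an element only when insert() reports
  // that the value was not already present.
  for (const auto& e : {Element{{"a"}}, Element{{"b"}}, Element{{"a"}}}) {
    if (seen.insert(e).second) std::cout << e.values[0] << "\n";
  }
}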

View File

@@ -558,6 +558,16 @@ filename: A path on the filesystem where we should cache the dataset. Note: this
will be a directory.
)doc");
REGISTER_OP("UniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Creates a dataset that contains the unique elements of `input_dataset`.
)doc");
REGISTER_OP("TextLineDataset") REGISTER_OP("TextLineDataset")
.Input("filenames: string") .Input("filenames: string")
.Input("compression_type: string") .Input("compression_type: string")

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -327,6 +329,57 @@ Status CurlHttpRequest::SetResultBuffer(std::vector<char>* out_buffer) {
return Status::OK();
}
Status CurlHttpRequest::SetResultBufferDirect(char* buffer, size_t size) {
CHECK(buffer != nullptr);
TF_RETURN_IF_ERROR(CheckInitialized());
TF_RETURN_IF_ERROR(CheckNotSent());
direct_response_ = DirectResponseState{buffer, size, 0};
libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA,
reinterpret_cast<void*>(this));
libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION,
&CurlHttpRequest::WriteCallbackDirect);
return Status::OK();
}
size_t CurlHttpRequest::WriteCallbackDirect(const void* ptr, size_t size,
size_t nmemb, void* userdata) {
CHECK(ptr != nullptr);
auto that = reinterpret_cast<CurlHttpRequest*>(userdata);
DirectResponseState* state = &that->direct_response_;
CHECK(state->buffer_ != nullptr);
CHECK(state->bytes_transferred_ <= state->buffer_size_);
size_t curl_bytes_received = size * nmemb;
size_t user_buffer_bytes_available =
state->buffer_size_ - state->bytes_transferred_;
// The HTTP server may send a response body that is longer than what we
// expected. We must not use CHECK() for this situation, because that would
// imply a code bug (in this client code) where none exists; the violation of
// expectations would have been caused by the server, not the client. So we
// report a log warning, if an HTTP server is misbehaving.
if (curl_bytes_received > user_buffer_bytes_available) {
LOG(WARNING) << "The HTTP response body that we received is longer than we "
"requested or expected. "
<< "Total bytes requested: " << state->buffer_size_
<< " Bytes received (so far) in HTTP response body: "
<< (state->bytes_transferred_ + curl_bytes_received);
}
size_t bytes_to_copy =
std::min<size_t>(curl_bytes_received, user_buffer_bytes_available);
memcpy(&state->buffer_[state->bytes_transferred_], ptr, bytes_to_copy);
state->bytes_transferred_ += bytes_to_copy;
return bytes_to_copy;
}
size_t CurlHttpRequest::GetResultBufferDirectBytesTransferred() {
CHECK(direct_response_.buffer_ != nullptr);
return direct_response_.bytes_transferred_;
}
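The core of WriteCallbackDirect is the clamping arithmetic: libcurl reports size * nmemb incoming bytes, at most the space remaining in the caller's buffer is copied, and any overflow from a misbehaving server is dropped (after the warning above). A standalone model of that callback, with illustrative names rather than the real libcurl types:

#include <algorithm>
#include <cstring>
#include <iostream>
#include <string>

// Caller-owned buffer plus a running count, like DirectResponseState.
struct DirectState {
  char* buffer;
  size_t buffer_size;
  size_t bytes_transferred;
};

size_t WriteDirect(const void* ptr, size_t size, size_t nmemb,
                   DirectState* state) {
  size_t received = size * nmemb;  // libcurl reports the size in two factors
  size_t available = state->buffer_size - state->bytes_transferred;
  size_t to_copy = std::min(received, available);  // never overflow the buffer
  memcpy(state->buffer + state->bytes_transferred, ptr, to_copy);
  state->bytes_transferred += to_copy;
  return to_copy;
}

int main() {
  char buf[8];
  DirectState state{buf, sizeof(buf), 0};
  std::string body = "0123456789";  // longer than the buffer on purpose
  WriteDirect(body.data(), 1, body.size(), &state);
  std::cout << std::string(buf, state.bytes_transferred) << "\n";  // 01234567
}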
Status CurlHttpRequest::SetTimeouts(uint32 connection, uint32 inactivity,
uint32 total) {
TF_RETURN_IF_ERROR(CheckInitialized());

View File

@@ -103,6 +103,26 @@ class CurlHttpRequest : public HttpRequest {
/// read. Existing content of the vector will be cleared.
Status SetResultBuffer(std::vector<char>* out_buffer) override;
/// \brief Specifies the buffer for receiving the response body, when the
/// caller knows the maximum size of the response body.
///
/// This method allows the caller to receive the response body without an
/// additional intermediate buffer allocation and copy. This method should
/// be called before calling Send(). After Send() has succeeded, the caller
/// should use the GetResultBufferDirectBytesTransferred() method in order
/// to learn how many bytes were transferred.
///
/// Using this method is mutually exclusive with using SetResultBuffer().
Status SetResultBufferDirect(char* buffer, size_t size) override;
/// \brief Returns the number of bytes (of the response body) that were
/// transferred, when using the SetResultBufferDirect() method. The returned
/// value will always be less than or equal to the 'size' parameter that
/// was passed to SetResultBufferDirect(). If the actual HTTP response body
/// was greater than 'size' bytes, then this transfer method will only copy
/// the first 'size' bytes, and the rest will be ignored.
size_t GetResultBufferDirectBytesTransferred() override;
/// \brief Returns the response headers of a completed request.
///
/// If the header is not found, returns an empty string.
@@ -127,6 +147,10 @@ class CurlHttpRequest : public HttpRequest {
/// A write callback in the form which can be accepted by libcurl.
static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
void* userdata);
/// Processes response body content received when using SetResultBufferDirect.
static size_t WriteCallbackDirect(const void* ptr, size_t size, size_t nmemb,
void* userdata);
/// A read callback in the form which can be accepted by libcurl.
static size_t ReadCallback(void* ptr, size_t size, size_t nmemb,
FILE* userdata);
@@ -150,6 +174,14 @@ class CurlHttpRequest : public HttpRequest {
size_t post_body_read_ = 0;
std::vector<char>* response_buffer_ = nullptr;
struct DirectResponseState {
char* buffer_;
size_t buffer_size_;
size_t bytes_transferred_;
};
DirectResponseState direct_response_ = {};
CURL* curl_ = nullptr;
curl_slist* curl_headers_ = nullptr;
curl_slist* resolve_list_ = nullptr;

View File

@@ -288,6 +288,39 @@ TEST(CurlHttpRequestTest, GetRequest) {
EXPECT_EQ(200, http_request.GetResponseCode());
}
TEST(CurlHttpRequestTest, GetRequest_Direct) {
FakeLibCurl libcurl("get response", 200);
CurlHttpRequest http_request(&libcurl);
TF_EXPECT_OK(http_request.Init());
std::vector<char> scratch(100, 0);
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer"));
TF_EXPECT_OK(http_request.SetRange(100, 199));
TF_EXPECT_OK(
http_request.SetResultBufferDirect(scratch.data(), scratch.capacity()));
TF_EXPECT_OK(http_request.Send());
string expected_response = "get response";
size_t response_bytes_transferred =
http_request.GetResultBufferDirectBytesTransferred();
EXPECT_EQ(response_bytes_transferred, expected_response.size());
EXPECT_EQ(
"get response",
string(scratch.begin(), scratch.begin() + response_bytes_transferred));
// Check interactions with libcurl.
EXPECT_TRUE(libcurl.is_initialized_);
EXPECT_EQ("http://www.testuri.com", libcurl.url_);
EXPECT_EQ("100-199", libcurl.range_);
EXPECT_EQ("", libcurl.custom_request_);
EXPECT_EQ(1, libcurl.headers_->size());
EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl.headers_)[0]);
EXPECT_FALSE(libcurl.is_post_);
EXPECT_EQ(200, http_request.GetResponseCode());
}
TEST(CurlHttpRequestTest, GetRequest_Empty) {
FakeLibCurl libcurl("", 200);
CurlHttpRequest http_request(&libcurl);

View File

@@ -123,8 +123,12 @@ Status FileBlockCache::MaybeFetch(const Key& key,
case FetchState::CREATED:
block->state = FetchState::FETCHING;
block->mu.unlock(); // Release the lock while making the API call.
status.Update(
block_fetcher_(key.first, key.second, block_size_, &block->data));
block->data.clear();
block->data.resize(block_size_, 0);
size_t bytes_transferred;
status.Update(block_fetcher_(key.first, key.second, block_size_,
block->data.data(), &bytes_transferred));
block->data.resize(bytes_transferred, 0);
block->mu.lock(); // Reacquire the lock immediately afterwards
if (status.ok()) {
downloaded_block = true;
@@ -150,15 +154,15 @@
}
}
Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
out->clear();
char* buffer, size_t* bytes_transferred) {
*bytes_transferred = 0;
if (n == 0) {
return Status::OK();
}
if (block_size_ == 0 || max_bytes_ == 0) {
// The cache is effectively disabled, so we pass the read through to the
// fetcher without breaking it up into blocks.
return block_fetcher_(filename, offset, n, out);
return block_fetcher_(filename, offset, n, buffer, bytes_transferred);
}
// Calculate the block-aligned start and end of the read.
size_t start = block_size_ * (offset / block_size_);
@@ -166,6 +170,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
if (finish < offset + n) {
finish += block_size_;
}
size_t total_bytes_transferred = 0;
// Now iterate through the blocks, reading them one at a time.
for (size_t pos = start; pos < finish; pos += block_size_) {
Key key = std::make_pair(filename, pos);
@@ -181,6 +186,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
// The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`.
*bytes_transferred = total_bytes_transferred;
return errors::OutOfRange("EOF at offset ", offset, " in file ", filename,
" at position ", pos, " with data size ",
data.size());
@@ -196,13 +202,16 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
end -= (pos + data.size()) - (offset + n);
}
if (begin < end) {
out->insert(out->end(), begin, end);
size_t bytes_to_copy = end - begin;
memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy);
total_bytes_transferred += bytes_to_copy;
}
if (data.size() < block_size_) {
// The block was a partial block and thus signals EOF at its upper bound.
break;
}
}
*bytes_transferred = total_bytes_transferred;
return Status::OK();
}

View File

@@ -43,8 +43,9 @@ class FileBlockCache {
/// cache is constructed. The returned Status should be OK as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
typedef std::function<Status(const string&, size_t, size_t,
std::vector<char>*)>
typedef std::function<Status(const string& filename, size_t offset,
size_t buffer_size, char* buffer,
size_t* bytes_transferred)>
BlockFetcher;
FileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness,
@@ -83,8 +84,8 @@ class FileBlockCache {
/// placed in `out`.
/// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
/// in `out`).
Status Read(const string& filename, size_t offset, size_t n,
std::vector<char>* out);
Status Read(const string& filename, size_t offset, size_t n, char* buffer,
size_t* bytes_transferred);
/// Remove all cached blocks for `filename`.
void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_);
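With the new BlockFetcher signature, a fetcher fills a caller-owned buffer and reports a byte count instead of appending to a std::vector. A self-contained sketch of a conforming fetcher over an in-memory file; an int stands in for tensorflow::Status here (assumption: 0 means OK), and the clamping at EOF mirrors what the tests below expect:

#include <algorithm>
#include <cstring>
#include <functional>
#include <iostream>
#include <string>

using Status = int;  // stand-in for tensorflow::Status; 0 == OK
constexpr Status kOk = 0;

// Mirrors the new BlockFetcher typedef from the diff.
using BlockFetcher = std::function<Status(
    const std::string& filename, size_t offset, size_t buffer_size,
    char* buffer, size_t* bytes_transferred)>;

int main() {
  std::string file_contents = "hello, block cache";
  // A fetcher over an in-memory "file", clamping at EOF like a real one.
  BlockFetcher fetcher = [&](const std::string&, size_t offset, size_t n,
                             char* buffer, size_t* bytes_transferred) {
    if (offset >= file_contents.size()) {
      *bytes_transferred = 0;  // EOF: nothing to copy, still OK
      return kOk;
    }
    size_t to_copy = std::min(n, file_contents.size() - offset);
    memcpy(buffer, file_contents.data() + offset, to_copy);
    *bytes_transferred = to_copy;
    return kOk;
  };

  char buf[8];
  size_t got = 0;
  fetcher("ignored", 7, sizeof(buf), buf, &got);
  std::cout << std::string(buf, got) << "\n";  // prints "block ca"
}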

View File

@@ -25,6 +25,18 @@ limitations under the License.
namespace tensorflow {
namespace {
Status ReadCache(FileBlockCache* cache, const string& filename, size_t offset,
size_t n, std::vector<char>* out) {
out->clear();
out->resize(n, 0);
size_t bytes_transferred = 0;
Status status =
cache->Read(filename, offset, n, out->data(), &bytes_transferred);
EXPECT_LE(bytes_transferred, n);
out->resize(bytes_transferred);  // shrink to the bytes actually read
return status;
}
TEST(FileBlockCacheTest, PassThrough) {
const string want_filename = "foo/bar";
const size_t want_offset = 42;
@@ -32,12 +44,13 @@ TEST(FileBlockCacheTest, PassThrough) {
int calls = 0;
auto fetcher = [&calls, want_filename, want_offset, want_n](
const string& got_filename, size_t got_offset,
size_t got_n, std::vector<char>* out) {
size_t got_n, char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(got_filename, want_filename);
EXPECT_EQ(got_offset, want_offset);
EXPECT_EQ(got_n, want_n);
calls++;
out->resize(got_n, 'x');
memset(buffer, 'x', got_n);
*bytes_transferred = got_n;
return Status::OK();
};
// If block_size, max_bytes, or both are zero, the cache is a pass-through.
@@ -45,11 +58,11 @@ TEST(FileBlockCacheTest, PassThrough) {
FileBlockCache cache2(0, 1, 0, fetcher);
FileBlockCache cache3(0, 0, 0, fetcher);
std::vector<char> out;
TF_EXPECT_OK(cache1.Read(want_filename, want_offset, want_n, &out));
TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 1);
TF_EXPECT_OK(cache2.Read(want_filename, want_offset, want_n, &out));
TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 2);
TF_EXPECT_OK(cache3.Read(want_filename, want_offset, want_n, &out));
TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 3);
}
@@ -63,13 +76,13 @@ TEST(FileBlockCacheTest, BlockAlignment) {
}
// The fetcher just fetches slices of the buffer.
auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
if (offset < buf.size()) {
if (offset + n > buf.size()) {
out->insert(out->end(), buf.begin() + offset, buf.end());
} else {
out->insert(out->end(), buf.begin() + offset, buf.begin() + offset + n);
}
size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
memcpy(buffer, buf.data() + offset, bytes_to_copy);
*bytes_transferred = bytes_to_copy;
} else {
*bytes_transferred = 0;
}
return Status::OK();
};
@@ -80,7 +93,7 @@ TEST(FileBlockCacheTest, BlockAlignment) {
for (size_t offset = 0; offset < 10; offset++) {
for (size_t n = block_size - 2; n <= block_size + 2; n++) {
std::vector<char> got;
TF_EXPECT_OK(cache.Read("", offset, n, &got));
TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got));
// Verify the size of the read.
if (offset + n <= size) {
// Expect a full read.
@@ -108,24 +121,27 @@ TEST(FileBlockCacheTest, CacheHits) {
const size_t block_size = 16;
std::set<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, std::vector<char>* out) {
size_t n, char* buffer,
size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
calls.insert(offset);
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
const uint32 block_count = 256;
FileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
std::vector<char> out;
out.resize(block_count, 0);
// The cache has space for `block_count` blocks. The loop with i = 0 should
// fill the cache, and the loop with i = 1 should be all cache hits. The
// fetcher checks that it is called once and only once for each offset (to
// fetch the corresponding block).
for (int i = 0; i < 2; i++) {
for (int j = 0; j < block_count; j++) {
TF_EXPECT_OK(cache.Read("", block_size * j, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out));
}
}
}
@@ -138,36 +154,39 @@ TEST(FileBlockCacheTest, OutOfRange) {
bool second_block = false;
auto fetcher = [block_size, file_size, &first_block, &second_block](
const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
size_t bytes_to_copy = 0;
if (offset == 0) {
// The first block (16 bytes) of the file.
out->resize(n, 'x');
memset(buffer, 'x', n);
bytes_to_copy = n;
first_block = true;
} else if (offset == block_size) {
// The second block (8 bytes) of the file.
out->resize(file_size - block_size, 'x');
bytes_to_copy = file_size - block_size;
memset(buffer, 'x', bytes_to_copy);
second_block = true;
}
*bytes_transferred = bytes_to_copy;
return Status::OK();
};
FileBlockCache cache(block_size, block_size, 0, fetcher);
std::vector<char> out;
// Reading the first 16 bytes should be fine.
TF_EXPECT_OK(cache.Read("", 0, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out));
EXPECT_TRUE(first_block);
EXPECT_EQ(out.size(), block_size);
// Reading at offset file_size + 4 will read the second block (since the read
// at file_size + 4 = 28 will be aligned to an offset of 16) but will return
// OutOfRange because the offset is past the end of the 24-byte file.
Status status = cache.Read("", file_size + 4, 4, &out);
Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
EXPECT_TRUE(second_block);
EXPECT_EQ(out.size(), 0);
// Reading the second full block will return 8 bytes, from a cache hit.
second_block = false;
TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_FALSE(second_block);
EXPECT_EQ(out.size(), file_size - block_size);
}
@@ -178,20 +197,22 @@ TEST(FileBlockCacheTest, Inconsistent) {
const size_t block_size = 16;
// This fetcher returns OK but only fills in one byte for any offset.
auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
out->resize(1, 'x');
EXPECT_GE(n, 1);
memset(buffer, 'x', 1);
*bytes_transferred = 1;
return Status::OK();
};
FileBlockCache cache(block_size, 2 * block_size, 0, fetcher);
std::vector<char> out;
// Read the second block; this should yield an OK status and a single byte.
TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_EQ(out.size(), 1);
// Now read the first block; this should yield an INTERNAL error because we
// had already cached a partial block at a later position.
Status status = cache.Read("", 0, block_size, &out);
Status status = ReadCache(&cache, "", 0, block_size, &out);
EXPECT_EQ(status.code(), error::INTERNAL);
}
@@ -199,14 +220,16 @@ TEST(FileBlockCacheTest, LRU) {
const size_t block_size = 16;
std::list<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, std::vector<char>* out) {
size_t n, char* buffer,
size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
if (!calls.empty()) {
EXPECT_EQ(offset, calls.front());
calls.pop_front();
}
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
const uint32 block_count = 2;
@@ -216,38 +239,39 @@
// fetcher calls that the cache makes.
calls.push_back(0);
// Cache miss - drains an element from `calls`.
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Cache hit - does not drain an element from `calls`.
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
calls.push_back(block_size);
// Cache miss followed by cache hit.
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
calls.push_back(2 * block_size);
// Cache miss followed by cache hit. Causes eviction of LRU element.
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// LRU element was at offset 0. Cache miss.
calls.push_back(0);
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Element at 2 * block_size is still in cache, and this read should update
// its position in the LRU list so it doesn't get evicted by the next read.
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// Element at block_size was evicted. Reading this element will also cause
// the LRU element (at 0) to be evicted.
calls.push_back(block_size);
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
// Element at 0 was evicted again.
calls.push_back(0);
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
}
TEST(FileBlockCacheTest, MaxStaleness) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
calls++;
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
std::vector<char> out;
@@ -256,14 +280,14 @@ TEST(FileBlockCacheTest, MaxStaleness) {
// expected.
FileBlockCache cache1(8, 16, 2 /* max staleness */, fetcher, env.get());
// Execute the first read to load the block.
TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Now advance the clock one second at a time and redo the read. The call
// count should advance every 3 seconds (i.e. every time the staleness is
// greater than 2).
for (int i = 1; i <= 10; i++) {
env->SetNowSeconds(i + 1);
TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1 + i / 3);
}
// Now create a cache with max staleness of 0, and verify that it also works
@@ -272,27 +296,27 @@ TEST(FileBlockCacheTest, MaxStaleness) {
env->SetNowSeconds(0);
FileBlockCache cache2(8, 16, 0 /* max staleness */, fetcher, env.get());
// Execute the first read to load the block.
TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Advance the clock by a huge amount and verify that the cached block is
// used to satisfy the read.
env->SetNowSeconds(365 * 24 * 60 * 60);  // ~1 year, just for fun.
TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
}
TEST(FileBlockCacheTest, RemoveFile) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
calls++;
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
if (offset > 0) {
// The first block is lower case and all subsequent blocks are upper case.
c = toupper(c);
}
out->clear();
out->resize(n, c);
memset(buffer, c, n);
*bytes_transferred = n;
return Status::OK();
};
// This cache has space for 4 blocks; we'll read from two files.
@@ -304,41 +328,41 @@ TEST(FileBlockCacheTest, RemoveFile) {
std::vector<char> A(n, 'A');
std::vector<char> B(n, 'B');
// Fill the cache.
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 1);
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 2);
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
EXPECT_EQ(calls, 3);
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// All four blocks should be in the cache now.
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// Remove the blocks from "a".
cache.RemoveFile("a");
// Both blocks from "b" should still be there.
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// The blocks from "a" should not be there.
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 5);
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 6);
}
@@ -346,10 +370,10 @@ TEST(FileBlockCacheTest, RemoveFile) {
TEST(FileBlockCacheTest, Prune) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
calls++;
out->clear();
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
std::vector<char> out;
@@ -360,20 +384,20 @@ TEST(FileBlockCacheTest, Prune) {
FileBlockCache cache(8, 32, 1 /* max staleness */, fetcher, env.get());
// Read three blocks into the cache, and advance the timestamp by one second
// with each read. Start with a block of "a" at the current timestamp `now`.
TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
// Now load a block of a different file "b" at timestamp `now` + 1
env->SetNowSeconds(now + 1);
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
// Now load a different block of file "a" at timestamp `now` + 1. When the
// first block of "a" expires, this block should also be removed because it
// also belongs to file "a".
TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
// Ensure that all blocks are in the cache (i.e. reads are cache hits).
EXPECT_EQ(cache.CacheSize(), 24);
EXPECT_EQ(calls, 3);
TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake timestamp so that "a" becomes stale via its first block.
env->SetNowSeconds(now + 2);
@@ -389,7 +413,7 @@ TEST(FileBlockCacheTest, Prune) {
// There should be one block left in the cache, and it should be the first
// block of "b".
EXPECT_EQ(cache.CacheSize(), 8);
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake time to `now` + 3, at which point "b" becomes stale.
env->SetNowSeconds(now + 3);
@@ -409,14 +433,14 @@ TEST(FileBlockCacheTest, ParallelReads) {
const int callers = 4;
BlockingCounter counter(callers);
auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
counter.DecrementCount();
if (!counter.WaitFor(std::chrono::seconds(10))) {
// This avoids having the test time out, which is harder to debug.
return errors::FailedPrecondition("desired concurrency not reached");
}
out->clear();
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
const int block_size = 8;
@@ -426,7 +450,8 @@ TEST(FileBlockCacheTest, ParallelReads) {
threads.emplace_back(
Env::Default()->StartThread({}, "caller", [&cache, i, block_size]() {
std::vector<char> out;
TF_EXPECT_OK(cache.Read("a", i * block_size, block_size, &out));
TF_EXPECT_OK(
ReadCache(&cache, "a", i * block_size, block_size, &out));
std::vector<char> x(block_size, 'x');
EXPECT_EQ(out, x);
}));
@@ -443,11 +468,12 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
Notification notification;
auto fetcher = [&num_requests, &notification, block_size](
const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset, 0);
num_requests++;
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
notification.Notify();
// Wait for other thread to issue read.
Env::Default()->SleepForMicroseconds(100000);  // 0.1 secs
@@ -458,17 +484,16 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
std::unique_ptr<Thread> concurrent(
Env::Default()->StartThread({}, "concurrent", [&cache, block_size] {
std::vector<char> out;
TF_EXPECT_OK(cache.Read("", 0, block_size / 2, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
}));
EXPECT_TRUE(WaitForNotificationWithTimeout(&notification, 10000))
<< "Timeout waiting for concurrent thread to start.";
std::vector<char> out;
TF_EXPECT_OK(cache.Read("", block_size / 2, block_size / 2, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
EXPECT_EQ(1, num_requests);
}
} // namespace
} // namespace tensorflow

View File

@@ -58,6 +58,10 @@ class TestHttpRequest : public HttpRequest {
Status SetResultBuffer(std::vector<char>* out_buffer) override {
return Status::OK();
}
Status SetResultBufferDirect(char* buffer, size_t size) override {
return Status::OK();
}
size_t GetResultBufferDirectBytesTransferred() override { return 0; }
string GetResponseHeader(const string& name) const override { return ""; } string GetResponseHeader(const string& name) const override { return ""; }
uint64 GetResponseCode() const override { return 0; } uint64 GetResponseCode() const override { return 0; }
View File
@ -285,11 +285,11 @@ class GcsRandomAccessFile : public RandomAccessFile {
Status Read(uint64 offset, size_t n, StringPiece* result, Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override { char* scratch) const override {
*result = StringPiece(); *result = StringPiece();
std::vector<char> out; size_t bytes_transferred;
TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, &out)); TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch,
std::memcpy(scratch, out.data(), std::min(out.size(), n)); &bytes_transferred));
*result = StringPiece(scratch, std::min(out.size(), n)); *result = StringPiece(scratch, bytes_transferred);
if (result->size() < n) { if (bytes_transferred < n) {
// This is not an error per se. The RandomAccessFile interface expects // This is not an error per se. The RandomAccessFile interface expects
// that Read returns OutOfRange if fewer bytes were read than requested. // that Read returns OutOfRange if fewer bytes were read than requested.
return errors::OutOfRange("EOF reached, ", result->size(), return errors::OutOfRange("EOF reached, ", result->size(),
@ -721,15 +721,17 @@ std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache(
std::unique_ptr<FileBlockCache> file_block_cache( std::unique_ptr<FileBlockCache> file_block_cache(
new FileBlockCache(block_size, max_bytes, max_staleness, new FileBlockCache(block_size, max_bytes, max_staleness,
[this](const string& filename, size_t offset, size_t n, [this](const string& filename, size_t offset, size_t n,
std::vector<char>* out) { char* buffer, size_t* bytes_transferred) {
return LoadBufferFromGCS(filename, offset, n, out); return LoadBufferFromGCS(filename, offset, n, buffer,
bytes_transferred);
})); }));
return file_block_cache; return file_block_cache;
} }
// A helper function to actually read the data from GCS. // A helper function to actually read the data from GCS.
Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
size_t n, std::vector<char>* out) { size_t n, char* buffer,
size_t* bytes_transferred) {
string bucket, object; string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object)); TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object));
@ -739,21 +741,23 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket, request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket,
"/", request->EscapeString(object)))); "/", request->EscapeString(object))));
TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1)); TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1));
TF_RETURN_IF_ERROR(request->SetResultBuffer(out)); TF_RETURN_IF_ERROR(request->SetResultBufferDirect(buffer, n));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read)); request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read));
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://", TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object); bucket, "/", object);
size_t bytes_read = request->GetResultBufferDirectBytesTransferred();
*bytes_transferred = bytes_read;
VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ " VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
<< offset << " of size: " << out->size(); << offset << " of size: " << bytes_read;
if (out->size() < block_size()) { if (bytes_read < block_size()) {
// Check stat cache to see if we encountered an interrupted read. // Check stat cache to see if we encountered an interrupted read.
FileStatistics stat; FileStatistics stat;
if (stat_cache_->Lookup(filename, &stat)) { if (stat_cache_->Lookup(filename, &stat)) {
if (offset + out->size() < stat.length) { if (offset + bytes_read < stat.length) {
return errors::Internal(strings::Printf( return errors::Internal(strings::Printf(
"File contents are inconsistent for file: %s @ %lu.", "File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset)); filename.c_str(), offset));
View File
@ -177,7 +177,7 @@ class GcsFileSystem : public FileSystem {
/// Loads file contents from GCS for a given filename, offset, and length. /// Loads file contents from GCS for a given filename, offset, and length.
Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n, Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n,
std::vector<char>* out); char* buffer, size_t* bytes_transferred);
std::unique_ptr<AuthProvider> auth_provider_; std::unique_ptr<AuthProvider> auth_provider_;
std::unique_ptr<HttpRequest::Factory> http_request_factory_; std::unique_ptr<HttpRequest::Factory> http_request_factory_;
View File
@ -101,6 +101,20 @@ class HttpRequest {
/// read. Existing content of the vector will be cleared. /// read. Existing content of the vector will be cleared.
virtual Status SetResultBuffer(std::vector<char>* out_buffer) = 0; virtual Status SetResultBuffer(std::vector<char>* out_buffer) = 0;
/// \brief Specifies the buffer for receiving the response body.
///
/// This method should be used when a caller knows the upper bound of the
/// size of the response data. The caller provides a pre-allocated buffer
/// and its size. After the Send() method is called, the
/// GetResultBufferDirectBytesTransferred() method may be used to learn the
/// number of bytes that were transferred using this method.
virtual Status SetResultBufferDirect(char* buffer, size_t size) = 0;
/// \brief Returns the number of bytes transferred, when using
/// SetResultBufferDirect(). This method may only be used when the response
/// body was received via SetResultBufferDirect().
virtual size_t GetResultBufferDirectBytesTransferred() = 0;
/// \brief Returns the response headers of a completed request. /// \brief Returns the response headers of a completed request.
/// ///
/// If the header is not found, returns an empty string. /// If the header is not found, returns an empty string.
View File
@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_ #ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_ #define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
#include <algorithm>
#include <fstream> #include <fstream>
#include <string> #include <string>
#include <vector> #include <vector>
@ -130,12 +131,25 @@ class FakeHttpRequest : public CurlHttpRequest {
buffer_ = buffer; buffer_ = buffer;
return Status::OK(); return Status::OK();
} }
Status SetResultBufferDirect(char* buffer, size_t size) override {
direct_result_buffer_ = buffer;
direct_result_buffer_size_ = size;
return Status::OK();
}
size_t GetResultBufferDirectBytesTransferred() override {
return direct_result_bytes_transferred_;
}
Status Send() override { Status Send() override {
EXPECT_EQ(expected_request_, actual_request()) EXPECT_EQ(expected_request_, actual_request())
<< "Unexpected HTTP request."; << "Unexpected HTTP request.";
if (buffer_) { if (buffer_) {
buffer_->insert(buffer_->begin(), response_.c_str(), buffer_->insert(buffer_->begin(), response_.data(),
response_.c_str() + response_.size()); response_.data() + response_.size());
} else if (direct_result_buffer_ != nullptr) {
size_t bytes_to_copy =
std::min<size_t>(direct_result_buffer_size_, response_.size());
memcpy(direct_result_buffer_, response_.data(), bytes_to_copy);
direct_result_bytes_transferred_ += bytes_to_copy;
} }
return response_status_; return response_status_;
} }
@ -178,6 +192,9 @@ class FakeHttpRequest : public CurlHttpRequest {
} }
std::vector<char>* buffer_ = nullptr; std::vector<char>* buffer_ = nullptr;
char* direct_result_buffer_ = nullptr;
size_t direct_result_buffer_size_ = 0;
size_t direct_result_bytes_transferred_ = 0;
string expected_request_; string expected_request_;
string actual_uri_; string actual_uri_;
string actual_request_; string actual_request_;
View File
@ -143,6 +143,10 @@ void Env::GetLocalTempDirectories(std::vector<string>* list) {
getenv("TMPDIR"), getenv("TMPDIR"),
getenv("TMP"), getenv("TMP"),
#if defined(__ANDROID__)
"/data/local/tmp",
#endif
// If all else fails // If all else fails
"/tmp", "/tmp",
}; };
View File
@ -0,0 +1,127 @@
/*
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package internal
/*
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/c_api.h"
*/
import "C"
import (
"errors"
"fmt"
"runtime"
"unsafe"
"github.com/golang/protobuf/proto"
pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
)
// Encapsulates a collection of API definitions.
//
// apiDefMap represents a map from operation name to corresponding
// ApiDef proto (see
// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto
// for ApiDef proto definition).
type apiDefMap struct {
c *C.TF_ApiDefMap
}
// Creates and returns a new apiDefMap instance.
//
// oplist is an OpList proto instance (see
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto
// for OpList proto definition).
func newAPIDefMap(oplist *pb.OpList) (*apiDefMap, error) {
// Create a buffer containing the serialized OpList.
opdefSerialized, err := proto.Marshal(oplist)
if err != nil {
return nil, fmt.Errorf("could not serialize OpDef for %s", oplist.String())
}
data := C.CBytes(opdefSerialized)
defer C.free(data)
opbuf := C.TF_NewBuffer()
defer C.TF_DeleteBuffer(opbuf)
opbuf.data = data
opbuf.length = C.size_t(len(opdefSerialized))
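// NOTE: opbuf does not take ownership of data. TF_NewBuffer leaves the
// deallocator unset, so the deferred TF_DeleteBuffer will not free data;
// the deferred C.free above is what releases it.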
// Create ApiDefMap.
status := C.TF_NewStatus()
defer C.TF_DeleteStatus(status)
capimap := C.TF_NewApiDefMap(opbuf, status)
if C.TF_GetCode(status) != C.TF_OK {
return nil, errors.New(C.GoString(C.TF_Message(status)))
}
apimap := &apiDefMap{capimap}
runtime.SetFinalizer(
apimap,
func(a *apiDefMap) {
C.TF_DeleteApiDefMap(a.c)
})
return apimap, nil
}
// Updates apiDefMap with the overrides specified in `data`.
//
// data - ApiDef text proto.
func (m *apiDefMap) Put(data string) error {
cdata := C.CString(data)
defer C.free(unsafe.Pointer(cdata))
status := C.TF_NewStatus()
defer C.TF_DeleteStatus(status)
C.TF_ApiDefMapPut(m.c, cdata, C.size_t(len(data)), status)
if C.TF_GetCode(status) != C.TF_OK {
return errors.New(C.GoString(C.TF_Message(status)))
}
return nil
}
// Returns ApiDef proto instance for the TensorFlow operation
// named `opname`.
func (m *apiDefMap) Get(opname string) (*pb.ApiDef, error) {
cname := C.CString(opname)
defer C.free(unsafe.Pointer(cname))
status := C.TF_NewStatus()
defer C.TF_DeleteStatus(status)
apidefBuf := C.TF_ApiDefMapGet(
m.c, cname, C.size_t(len(opname)), status)
defer C.TF_DeleteBuffer(apidefBuf)
if C.TF_GetCode(status) != C.TF_OK {
return nil, errors.New(C.GoString(C.TF_Message(status)))
}
if apidefBuf == nil {
return nil, fmt.Errorf("could not find ApiDef for %s", opname)
}
var (
apidef = new(pb.ApiDef)
size = int(apidefBuf.length)
// A []byte backed by C memory.
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
data = (*[1 << 30]byte)(unsafe.Pointer(apidefBuf.data))[:size:size]
err = proto.Unmarshal(data, apidef)
)
if err != nil {
return nil, err
}
return apidef, nil
}
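Taken together: the map is built once from an OpList, layered with text-proto overrides, then queried per op. A minimal sketch of that sequence (the demoApiDefMap helper, the "MyOp" name and the override text are illustrative assumptions, not part of this change):
func demoApiDefMap() (*pb.ApiDef, error) {
	oplist := &pb.OpList{Op: []*pb.OpDef{{Name: "MyOp"}}}
	apimap, err := newAPIDefMap(oplist)
	if err != nil {
		return nil, err
	}
	// Overrides use the same ApiDef text-proto shape as api_def_*.pbtxt files.
	override := `op: <
  graph_op_name: "MyOp"
  summary: "Does my op."
>`
	if err := apimap.Put(override); err != nil {
		return nil, err
	}
	// Returns the merged ApiDef, with the override applied.
	return apimap.Get("MyOp")
}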
View File
@ -29,12 +29,18 @@ limitations under the License.
// encountered. // encountered.
package internal package internal
// #include "tensorflow/c/c_api.h" /*
#include <stdlib.h>
#include "tensorflow/c/c_api.h"
*/
import "C" import "C"
import ( import (
"fmt" "fmt"
"io" "io"
"io/ioutil"
"path"
"reflect" "reflect"
"strings" "strings"
"text/template" "text/template"
@ -47,15 +53,23 @@ import (
// GenerateFunctionsForRegisteredOps writes a Go source code file to w // GenerateFunctionsForRegisteredOps writes a Go source code file to w
// containing functions for each TensorFlow operation registered in the address // containing functions for each TensorFlow operation registered in the address
// space of the calling process. // space of the calling process.
func GenerateFunctionsForRegisteredOps(w io.Writer) error { // apidefDirs should be a list of directories containing
ops, err := registeredOps() // files to load.
func GenerateFunctionsForRegisteredOps(
w io.Writer, apidefDirs []string) error {
ops, apimap, err := registeredOps()
if err != nil { if err != nil {
return err return err
} }
return generateFunctionsForOps(w, ops) for _, dir := range apidefDirs {
if err = updateAPIDefs(apimap, dir); err != nil {
return err
}
}
return generateFunctionsForOps(w, ops, apimap)
} }
func registeredOps() (*pb.OpList, error) { func registeredOps() (*pb.OpList, *apiDefMap, error) {
buf := C.TF_GetAllOpList() buf := C.TF_GetAllOpList()
defer C.TF_DeleteBuffer(buf) defer C.TF_DeleteBuffer(buf)
var ( var (
@ -66,10 +80,31 @@ func registeredOps() (*pb.OpList, error) {
data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size] data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
err = proto.Unmarshal(data, list) err = proto.Unmarshal(data, list)
) )
return list, err if err != nil {
return nil, nil, err
}
apimap, err := newAPIDefMap(list)
return list, apimap, err
} }
func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error { func updateAPIDefs(m *apiDefMap, dir string) error {
files, err := ioutil.ReadDir(dir)
if err != nil {
return err
}
for _, file := range files {
data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
if err != nil {
return fmt.Errorf("failed to read %q: %v", file.Name(), err)
}
if err = m.Put(string(data)); err != nil {
return fmt.Errorf("failed to process %q: %v", file.Name(), err)
}
}
return nil
}
func generateFunctionsForOps(w io.Writer, ops *pb.OpList, apimap *apiDefMap) error {
thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath() thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
if err := tmplHeader.Execute(w, thisPackage); err != nil { if err := tmplHeader.Execute(w, thisPackage); err != nil {
return err return err
@ -83,14 +118,18 @@ func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error {
if blacklist[op.Name] { if blacklist[op.Name] {
continue continue
} }
if err := generateFunctionForOp(w, op); err != nil { apidef, err := apimap.Get(op.Name)
if err != nil {
return err
}
if err := generateFunctionForOp(w, op, apidef); err != nil {
return err return err
} }
} }
return nil return nil
} }
func generateFunctionForOp(w io.Writer, op *pb.OpDef) error { func generateFunctionForOp(w io.Writer, op *pb.OpDef, apidef *pb.ApiDef) error {
if strings.HasPrefix(op.Name, "_") { // Internal operation if strings.HasPrefix(op.Name, "_") { // Internal operation
return nil return nil
} }
@ -112,12 +151,16 @@ func generateFunctionForOp(w io.Writer, op *pb.OpDef) error {
return nil return nil
} }
} }
if op.Summary == "" { if apidef.Summary == "" {
// Undocumented operation, perhaps a sign of not being ready to // Undocumented operation, perhaps a sign of not being ready to
// export. // export.
return nil return nil
} }
return tmplOp.Execute(w, newTmplArgs(op)) tmplArgs, err := newTmplArgs(op, apidef)
if err != nil {
return err
}
return tmplOp.Execute(w, tmplArgs)
} }
var ( var (
@ -172,7 +215,7 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in
type {{.Op.Name}}Attr func(optionalAttr) type {{.Op.Name}}Attr func(optionalAttr)
{{range .OptionalAttrs}} {{range .OptionalAttrs}}
// {{$.Op.Name}}{{CamelCase .Name}} sets the optional {{.Name}} attribute to value. // {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
{{- if .Description}} {{- if .Description}}
// //
// value: {{MakeComment .Description}} // value: {{MakeComment .Description}}
@ -180,9 +223,9 @@ type {{.Op.Name}}Attr func(optionalAttr)
// If not specified, defaults to {{StripLeadingColon .DefaultValue}} // If not specified, defaults to {{StripLeadingColon .DefaultValue}}
{{- if .HasMinimum}} {{- if .HasMinimum}}
// //
// {{if IsListAttr .}}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}} // {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
{{- end}} {{- end}}
func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr { func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
return func(m optionalAttr) { return func(m optionalAttr) {
m[{{printf "%q" .Name}}] = value m[{{printf "%q" .Name}}] = value
} }
@ -192,14 +235,14 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
{{- /* Create a godoc friendly comment. */ -}} {{- /* Create a godoc friendly comment. */ -}}
// {{MakeComment .Op.Summary}} // {{MakeComment .APIDef.Summary}}
{{- with .Op.Deprecation}} {{- with .Op.Deprecation}}
// //
// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}} // DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
{{- end -}} {{- end -}}
{{- with .Op.Description}} {{- with .APIDef.Description}}
// //
// {{MakeComment .}} // {{MakeComment .}}
{{- end -}} {{- end -}}
@ -207,11 +250,11 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
{{- if .DescribeArguments}} {{- if .DescribeArguments}}
// //
// Arguments: // Arguments:
{{- range .Op.InputArg}} {{- range .InArgsReordered}}
// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}} // {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}} {{- end -}}
{{- range .RequiredAttrs}} {{- range .RequiredAttrs}}
// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}} // {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
@ -221,12 +264,12 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
{{- else }} {{- else }}
{{- if .DescribeOutputs}} {{- if .DescribeOutputs}}
// //
{{- if ((len .Op.OutputArg) eq 1) }} {{- if ((len .OutArgs) eq 1) }}
// Returns {{range .Op.OutputArg}}{{MakeComment .Description}}{{end}} // Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
{{- else }} {{- else }}
// Returns: // Returns:
{{- range .Op.OutputArg}} {{- range .OutArgs}}
// {{Identifier .Name}}{{if .Description}}: {{MakeComment .Description}}{{end}} // {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
@ -247,15 +290,15 @@ func {{.Op.Name}}
*/ -}} */ -}}
(scope *Scope (scope *Scope
{{- range $i, $a := .Op.InputArg}}, {{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}} {{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.Name}} {{GoType $a.Type}}{{end -}} {{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} {{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
) )
{{- /* Construct outputs: len(OpDef.OutputArg) or a *tf.Operation */ -}} {{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}
{{if .Op.OutputArg -}} {{if .OutArgs -}}
({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}) ({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
{{- else -}} {{- else -}}
(o *tf.Operation) (o *tf.Operation)
{{- end }} { {{- end }} {
@ -263,7 +306,7 @@ func {{.Op.Name}}
return return
} }
{{if .HasAttrs -}} {{if .HasAttrs -}}
attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .Name}},{{end}}} attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
{{if .OptionalAttrs -}} {{if .OptionalAttrs -}}
for _, a := range optional { for _, a := range optional {
a(attrs) a(attrs)
@ -272,16 +315,16 @@ func {{.Op.Name}}
{{end -}} {{end -}}
opspec := tf.OpSpec{ opspec := tf.OpSpec{
Type: {{printf "%q" .Op.Name}}, Type: {{printf "%q" .Op.Name}},
{{if .Op.InputArg -}} {{if .InArgs -}}
Input: []tf.Input{ Input: []tf.Input{
{{range .Op.InputArg}}{{if IsListArg .}}tf.OutputList({{Identifier .Name}}){{else}}{{Identifier .Name}}{{end}}, {{end}} {{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
}, },
{{- end}} {{- end}}
{{- if .HasAttrs}} {{- if .HasAttrs}}
Attrs: attrs, Attrs: attrs,
{{- end}} {{- end}}
} }
{{- if .Op.OutputArg}} {{- if .OutArgs}}
{{- if .HasListOutput}} {{- if .HasListOutput}}
op := scope.AddOperation(opspec) op := scope.AddOperation(opspec)
if scope.Err() != nil { if scope.Err() != nil {
@ -289,43 +332,105 @@ func {{.Op.Name}}
} }
var idx int var idx int
var err error var err error
{{- range $i, $a := .Op.OutputArg}} {{- range $i, $a := .OutArgs}}
{{- if IsListArg $a}} {{- if $a.IsListArg}}
if {{Identifier .Name}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
scope.UpdateErr({{printf "%q" $.Op.Name}}, err) scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
return return
} }
{{- else }} {{- else }}
{{Identifier .Name}} = op.Output(idx) {{Identifier .RenameTo}} = op.Output(idx)
{{- end }}{{- /* if IsListArg */}} {{- end }}{{- /* if IsListArg */}}
{{- end }}{{- /* range .Op.OutputArg */}} {{- end }}{{- /* range .OutArgs */}}
return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier .Name}}{{end}} return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
{{- else }} {{- else }}
op := scope.AddOperation(opspec) op := scope.AddOperation(opspec)
return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
{{- end }}{{- /* if .HasListOutput */}} {{- end }}{{- /* if .HasListOutput */}}
{{- else }} {{- else }}
return scope.AddOperation(opspec) return scope.AddOperation(opspec)
{{- end }}{{- /* if .Op.OutputArg */}} {{- end }}{{- /* if .OutArgs */}}
} }
`)) `))
) )
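// attrWrapper pairs an OpDef attribute with its ApiDef counterpart so the
// templates above can read one merged view: type information and minimums
// come from the OpDef side; names, descriptions and default values from
// the ApiDef side.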
type attrWrapper struct {
op *pb.OpDef_AttrDef
api *pb.ApiDef_Attr
}
func (a *attrWrapper) Name() string { return a.api.Name }
func (a *attrWrapper) RenameTo() string { return a.api.RenameTo }
func (a *attrWrapper) Description() string { return a.api.Description }
func (a *attrWrapper) Type() string { return a.op.Type }
func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) }
func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum }
func (a *attrWrapper) Minimum() int64 { return a.op.Minimum }
func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
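// argWrapper plays the same role for input and output arguments.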
type argWrapper struct {
op *pb.OpDef_ArgDef
api *pb.ApiDef_Arg
}
func (a *argWrapper) Name() string { return a.api.Name }
func (a *argWrapper) RenameTo() string { return a.api.RenameTo }
func (a *argWrapper) Description() string { return a.api.Description }
func (a *argWrapper) IsListArg() bool { return isListArg(a.op) }
type tmplArgs struct { type tmplArgs struct {
Op *pb.OpDef Op *pb.OpDef
APIDef *pb.ApiDef
// Op.Attr is split into two categories // Op.Attr is split into two categories
// (1) Required: These must be specified by the client and are thus // (1) Required: These must be specified by the client and are thus
// included in the function signature. // included in the function signature.
// (2) Optional: These need not be specified (as they have default // (2) Optional: These need not be specified (as they have default
// values) and thus do not appear in the function signature. // values) and thus do not appear in the function signature.
RequiredAttrs []*pb.OpDef_AttrDef RequiredAttrs []*attrWrapper
OptionalAttrs []*pb.OpDef_AttrDef OptionalAttrs []*attrWrapper
InArgs []*argWrapper
// Input arguments ordered based on arg_order field of ApiDef.
InArgsReordered []*argWrapper
OutArgs []*argWrapper
} }
func newTmplArgs(op *pb.OpDef) *tmplArgs { func newTmplArgs(op *pb.OpDef, apidef *pb.ApiDef) (*tmplArgs, error) {
ret := tmplArgs{Op: op} ret := tmplArgs{Op: op, APIDef: apidef}
// Setup InArgs field
for i, in := range op.InputArg {
argCombined := argWrapper{op: in, api: apidef.InArg[i]}
ret.InArgs = append(ret.InArgs, &argCombined)
}
// Setup OutArgs field
for i, out := range op.OutputArg {
argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
ret.OutArgs = append(ret.OutArgs, &argCombined)
}
// Setup InArgsReordered field
for _, argName := range apidef.ArgOrder {
// Find the argument in op.InputArg
argIndex := -1
for i, in := range op.InputArg {
if in.Name == argName {
argIndex = i
break
}
}
if argIndex == -1 {
return nil, fmt.Errorf(
"couldn't find argument %s in ApiDef for op %s",
argName, op.Name)
}
argCombined := argWrapper{
op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
}
if len(op.Attr) == 0 { if len(op.Attr) == 0 {
return &ret return &ret, nil
} }
// Attributes related to the InputArg's type are inferred automatically // Attributes related to the InputArg's type are inferred automatically
// and are not exposed to the client. // and are not exposed to the client.
@ -341,28 +446,29 @@ func newTmplArgs(op *pb.OpDef) *tmplArgs {
inferred[in.NumberAttr] = true inferred[in.NumberAttr] = true
} }
} }
for _, attr := range op.Attr { for i, attr := range op.Attr {
if inferred[attr.Name] { if inferred[attr.Name] {
continue continue
} }
attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
if attr.DefaultValue == nil { if attr.DefaultValue == nil {
ret.RequiredAttrs = append(ret.RequiredAttrs, attr) ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
} else { } else {
ret.OptionalAttrs = append(ret.OptionalAttrs, attr) ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
} }
} }
return &ret return &ret, nil
} }
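To make the reordering concrete, a hypothetical call (the op, argument names and helper are made up for illustration; the ApiDefOverrides test case below exercises the same path through generated source):
func demoArgOrder() error {
	op := &pb.OpDef{
		Name:     "TestOp",
		InputArg: []*pb.OpDef_ArgDef{{Name: "a"}, {Name: "b"}},
	}
	apidef := &pb.ApiDef{
		GraphOpName: "TestOp",
		InArg: []*pb.ApiDef_Arg{
			{Name: "a", RenameTo: "aa"},
			{Name: "b", RenameTo: "bb"},
		},
		ArgOrder: []string{"b", "a"},
	}
	args, err := newTmplArgs(op, apidef)
	if err != nil {
		return err
	}
	// args.InArgs keeps OpDef order (aa, bb); args.InArgsReordered follows
	// arg_order (bb, aa), which is the order used in the generated signature.
	_ = args
	return nil
}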
func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 } func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
func (a *tmplArgs) DescribeArguments() bool { func (a *tmplArgs) DescribeArguments() bool {
for _, arg := range a.Op.InputArg { for _, arg := range a.InArgs {
if arg.Description != "" { if arg.Description() != "" {
return true return true
} }
} }
for _, attr := range a.RequiredAttrs { for _, attr := range a.RequiredAttrs {
if attr.Description != "" { if attr.Description() != "" {
return true return true
} }
} }
@ -370,16 +476,16 @@ func (a *tmplArgs) DescribeArguments() bool {
} }
func (a *tmplArgs) DescribeOutputs() bool { func (a *tmplArgs) DescribeOutputs() bool {
for _, arg := range a.Op.OutputArg { for _, arg := range a.OutArgs {
if arg.Description != "" { if arg.Description() != "" {
return true return true
} }
} }
return false return false
} }
func (a *tmplArgs) HasListOutput() bool { func (a *tmplArgs) HasListOutput() bool {
for _, arg := range a.Op.OutputArg { for _, arg := range a.OutArgs {
if isListArg(arg) { if arg.IsListArg() {
return true return true
} }
} }
View File
@ -25,19 +25,44 @@ import (
pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework" pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
) )
// GetAPIDef creates an ApiDef based on opdef and applies overrides
// from apidefText (ApiDef text proto).
func GetAPIDef(t *testing.T, opdef *pb.OpDef, apidefText string) *pb.ApiDef {
opdefList := &pb.OpList{Op: []*pb.OpDef{opdef}}
apimap, err := newAPIDefMap(opdefList)
if err != nil {
t.Fatal(err)
}
err = apimap.Put(apidefText)
if err != nil {
t.Fatal(err)
}
apidef, err := apimap.Get(opdef.Name)
if err != nil {
t.Fatal(err)
}
return apidef
}
func TestGenerateOp(t *testing.T) { func TestGenerateOp(t *testing.T) {
// TestGenerateOp validates the generated source code for an op. // TestGenerateOp validates the generated source code for an op.
// The OpDef for the test cases are simplified forms of real ops. // The OpDef for the test cases are simplified forms of real ops.
testdata := []struct { testdata := []struct {
tag string tag string
opdef string opdef string
apidef string
wanted string wanted string
}{ }{
{ {
tag: "NoOp", tag: "NoOp",
opdef: ` opdef: `
name: "NoOp" name: "NoOp"
`,
apidef: `
op: <
graph_op_name: "NoOp"
summary: "No. Op." summary: "No. Op."
>
`, `,
wanted: ` wanted: `
// No. Op. // No. Op.
@ -80,8 +105,13 @@ attr: <
> >
> >
> >
`,
apidef: `
op: <
graph_op_name: "Add"
summary: "Returns x + y element-wise." summary: "Returns x + y element-wise."
description: "Blah blah", description: "Blah blah",
>
`, `,
wanted: ` wanted: `
// Returns x + y element-wise. // Returns x + y element-wise.
@ -122,7 +152,12 @@ attr: <
name: "DstT" name: "DstT"
type: "type" type: "type"
> >
`,
apidef: `
op: <
graph_op_name: "Cast"
summary: "Cast x of type SrcT to y of DstT." summary: "Cast x of type SrcT to y of DstT."
>
`, `,
wanted: ` wanted: `
// Cast x of type SrcT to y of DstT. // Cast x of type SrcT to y of DstT.
@ -149,12 +184,10 @@ func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) {
name: "DecodeJpeg" name: "DecodeJpeg"
input_arg: < input_arg: <
name: "contents" name: "contents"
description: "0-D. The JPEG-encoded image."
type: DT_STRING type: DT_STRING
> >
output_arg: < output_arg: <
name: "image" name: "image"
description: "3-D with shape [height, width, channels]"
type: DT_UINT8 type: DT_UINT8
> >
attr: < attr: <
@ -163,7 +196,6 @@ attr: <
default_value: < default_value: <
i: 0 i: 0
> >
description: "Number of color channels for the decoded image."
> >
attr: < attr: <
name: "fancy_upscaling" name: "fancy_upscaling"
@ -171,7 +203,6 @@ attr: <
default_value: < default_value: <
b: true b: true
> >
description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
> >
attr: < attr: <
name: "acceptable_fraction" name: "acceptable_fraction"
@ -179,10 +210,34 @@ attr: <
default_value: < default_value: <
f: 1 f: 1
> >
>
`,
apidef: `
op: <
graph_op_name: "DecodeJpeg"
in_arg: <
name: "contents"
description: "0-D. The JPEG-encoded image."
>
out_arg: <
name: "image"
description: "3-D with shape [height, width, channels]"
>
attr: <
name: "channels"
description: "Number of color channels for the decoded image."
>
attr: <
name: "fancy_upscaling"
description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
>
attr: <
name: "acceptable_fraction"
description: "The minimum required fraction of lines before a truncated\ninput is accepted." description: "The minimum required fraction of lines before a truncated\ninput is accepted."
> >
summary: "Decode a JPEG-encoded image to a uint8 tensor." summary: "Decode a JPEG-encoded image to a uint8 tensor."
description: "Norna dorna fjord\nkajorna\nhahaha" description: "Norna dorna fjord\nkajorna\nhahaha"
>
`, `,
wanted: ` wanted: `
// DecodeJpegAttr is an optional argument to DecodeJpeg. // DecodeJpegAttr is an optional argument to DecodeJpeg.
@ -270,7 +325,12 @@ attr: <
name: "T" name: "T"
type: "type" type: "type"
> >
`,
apidef: `
op: <
graph_op_name: "TwoOutputs"
summary: "Op that produces multiple outputs" summary: "Op that produces multiple outputs"
>
`, `,
wanted: ` wanted: `
// Op that produces multiple outputs // Op that produces multiple outputs
@ -326,8 +386,13 @@ attr: <
> >
> >
> >
`,
apidef: `
op: <
graph_op_name: "ShapeN"
summary: "Returns shape of tensors." summary: "Returns shape of tensors."
description: "Some description here." description: "Some description here."
>
`, `,
wanted: ` wanted: `
// ShapeNAttr is an optional argument to ShapeN. // ShapeNAttr is an optional argument to ShapeN.
@ -371,6 +436,102 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
} }
return output return output
} }
`,
},
{
tag: "ApiDefOverrides",
opdef: `
name: "TestOp"
input_arg: <
name: "a"
type: DT_STRING
>
input_arg: <
name: "b"
type: DT_STRING
>
output_arg: <
name: "c"
type: DT_UINT8
>
attr: <
name: "d"
type: "int"
default_value: <
i: 0
>
>
`,
apidef: `
op: <
graph_op_name: "TestOp"
in_arg: <
name: "a"
rename_to: "aa"
description: "Description for aa."
>
in_arg: <
name: "b"
rename_to: "bb"
description: "Description for bb."
>
arg_order: "b"
arg_order: "a"
out_arg: <
name: "c"
rename_to: "cc"
description: "Description for cc."
>
attr: <
name: "d"
rename_to: "dd"
description: "Description for dd."
>
summary: "Summary for TestOp."
description: "Description for TestOp."
>
`,
wanted: `
// TestOpAttr is an optional argument to TestOp.
type TestOpAttr func(optionalAttr)
// TestOpDd sets the optional dd attribute to value.
//
// value: Description for dd.
// If not specified, defaults to 0
func TestOpDd(value int64) TestOpAttr {
return func(m optionalAttr) {
m["d"] = value
}
}
// Summary for TestOp.
//
// Description for TestOp.
//
// Arguments:
// bb: Description for bb.
// aa: Description for aa.
//
// Returns Description for cc.
func TestOp(scope *Scope, bb tf.Output, aa tf.Output, optional ...TestOpAttr) (cc tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "TestOp",
Input: []tf.Input{
aa, bb,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
`, `,
}, },
} }
@ -378,11 +539,13 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
for _, test := range testdata { for _, test := range testdata {
t.Run(test.tag, func(t *testing.T) { t.Run(test.tag, func(t *testing.T) {
var opdef pb.OpDef var opdef pb.OpDef
var apidef *pb.ApiDef
var buf bytes.Buffer var buf bytes.Buffer
if err := proto.UnmarshalText(test.opdef, &opdef); err != nil { if err := proto.UnmarshalText(test.opdef, &opdef); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := generateFunctionForOp(&buf, &opdef); err != nil { apidef = GetAPIDef(t, &opdef, test.apidef)
if err := generateFunctionForOp(&buf, &opdef, apidef); err != nil {
t.Fatal(err) t.Fatal(err)
} }
got, err := format.Source(buf.Bytes()) got, err := format.Source(buf.Bytes())
View File
@ -27,6 +27,7 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/tensorflow/tensorflow/tensorflow/go/genop/internal" "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal"
) )
@ -35,6 +36,7 @@ func main() {
var ( var (
filename = flag.String("outfile", "", "File to write generated source code to.") filename = flag.String("outfile", "", "File to write generated source code to.")
header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. Can be empty") header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. Can be empty")
apiDefDirs = flag.String("api_def_dirs", "", "Comma-separated directories containing api_def_*.pbtxt files.")
buf bytes.Buffer buf bytes.Buffer
) )
flag.Parse() flag.Parse()
@ -51,7 +53,13 @@ func main() {
} }
os.MkdirAll(filepath.Dir(*filename), 0755) os.MkdirAll(filepath.Dir(*filename), 0755)
if err := internal.GenerateFunctionsForRegisteredOps(&buf); err != nil { apiDefDirsList := []string{}
if len(*apiDefDirs) > 0 {
apiDefDirsList = strings.Split(*apiDefDirs, ",")
}
if err := internal.GenerateFunctionsForRegisteredOps(
&buf, apiDefDirsList); err != nil {
log.Fatal(err) log.Fatal(err)
} }
formatted, err := format.Source(buf.Bytes()) formatted, err := format.Source(buf.Bytes())
View File
@ -52,6 +52,12 @@ py_library(
]), ]),
) )
py_library(
name = "common",
srcs = ["lib/common.py"],
srcs_version = "PY2AND3",
)
py_library( py_library(
name = "debug_graphs", name = "debug_graphs",
srcs = ["lib/debug_graphs.py"], srcs = ["lib/debug_graphs.py"],
@ -117,6 +123,7 @@ py_library(
srcs = ["lib/source_remote.py"], srcs = ["lib/source_remote.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":common",
":debug_service_pb2_grpc", ":debug_service_pb2_grpc",
"//tensorflow/core/debug:debug_service_proto_py", "//tensorflow/core/debug:debug_service_proto_py",
"//tensorflow/python/profiler:tfprof_logger", "//tensorflow/python/profiler:tfprof_logger",
@ -193,6 +200,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":command_parser", ":command_parser",
":common",
":debugger_cli_common", ":debugger_cli_common",
":tensor_format", ":tensor_format",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",
@ -334,7 +342,11 @@ py_library(
name = "grpc_wrapper", name = "grpc_wrapper",
srcs = ["wrappers/grpc_wrapper.py"], srcs = ["wrappers/grpc_wrapper.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [":framework"], deps = [
":common",
":framework",
":source_remote",
],
) )
py_library( py_library(
@ -345,6 +357,7 @@ py_library(
":analyzer_cli", ":analyzer_cli",
":cli_shared", ":cli_shared",
":command_parser", ":command_parser",
":common",
":debug_data", ":debug_data",
":debugger_cli_common", ":debugger_cli_common",
":framework", ":framework",
@ -439,6 +452,20 @@ py_binary(
], ],
) )
py_test(
name = "common_test",
size = "small",
srcs = ["lib/common_test.py"],
srcs_version = "PY2AND3",
deps = [
":common",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:platform_test",
],
)
py_test( py_test(
name = "debug_graphs_test", name = "debug_graphs_test",
size = "small", size = "small",
View File
@ -55,7 +55,8 @@ def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig( rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True, disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF, constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options) return config_pb2.ConfigProto(graph_options=graph_options)
View File
@ -25,6 +25,7 @@ import six
from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format from tensorflow.python.debug.cli import tensor_format
from tensorflow.python.debug.lib import common
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -214,51 +215,6 @@ def error(msg):
RL("ERROR: " + msg, COLOR_RED)]) RL("ERROR: " + msg, COLOR_RED)])
def get_graph_element_name(elem):
"""Obtain the name or string representation of a graph element.
If the graph element has the attribute "name", return name. Otherwise, return
a __str__ representation of the graph element. Certain graph elements, such as
`SparseTensor`s, do not have the attribute "name".
Args:
elem: The graph element in question.
Returns:
If the attribute 'name' is available, return the name. Otherwise, return
str(fetch).
"""
return elem.name if hasattr(elem, "name") else str(elem)
def _get_fetch_names(fetches):
"""Get a flattened list of the names in run() call fetches.
Args:
fetches: Fetches of the `Session.run()` call. It maybe a Tensor, an
Operation or a Variable. It may also be nested lists, tuples or
dicts. See doc of `Session.run()` for more details.
Returns:
(list of str) A flattened list of fetch names from `fetches`.
"""
lines = []
if isinstance(fetches, (list, tuple)):
for fetch in fetches:
lines.extend(_get_fetch_names(fetch))
elif isinstance(fetches, dict):
for key in fetches:
lines.extend(_get_fetch_names(fetches[key]))
else:
# This ought to be a Tensor, an Operation or a Variable, for which the name
# attribute should be available. (Bottom-out condition of the recursion.)
lines.append(get_graph_element_name(fetches))
return lines
def _recommend_command(command, description, indent=2, create_link=False): def _recommend_command(command, description, indent=2, create_link=False):
"""Generate a RichTextLines object that describes a recommended command. """Generate a RichTextLines object that describes a recommended command.
@ -327,14 +283,14 @@ def get_run_start_intro(run_call_count,
(RichTextLines) Formatted intro message about the `Session.run()` call. (RichTextLines) Formatted intro message about the `Session.run()` call.
""" """
fetch_lines = _get_fetch_names(fetches) fetch_lines = common.get_flattened_names(fetches)
if not feed_dict: if not feed_dict:
feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")] feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")]
else: else:
feed_dict_lines = [] feed_dict_lines = []
for feed_key in feed_dict: for feed_key in feed_dict:
feed_key_name = get_graph_element_name(feed_key) feed_key_name = common.get_graph_element_name(feed_key)
feed_dict_line = debugger_cli_common.RichLine(" ") feed_dict_line = debugger_cli_common.RichLine(" ")
feed_dict_line += debugger_cli_common.RichLine( feed_dict_line += debugger_cli_common.RichLine(
feed_key_name, feed_key_name,
@ -446,10 +402,10 @@ def get_run_short_description(run_call_count,
description = "run #%d: " % run_call_count description = "run #%d: " % run_call_count
if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
description += "1 fetch (%s); " % get_graph_element_name(fetches) description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
else: else:
# Could be (nested) list, tuple, dict or namedtuple. # Could be (nested) list, tuple, dict or namedtuple.
num_fetches = len(_get_fetch_names(fetches)) num_fetches = len(common.get_flattened_names(fetches))
if num_fetches > 1: if num_fetches > 1:
description += "%d fetches; " % num_fetches description += "%d fetches; " % num_fetches
else: else:
View File
@ -0,0 +1,87 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common values and methods for TensorFlow Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
GRPC_URL_PREFIX = "grpc://"
# A key for a Session.run() call.
RunKey = collections.namedtuple("RunKey", ["feed_names", "fetch_names"])
def get_graph_element_name(elem):
"""Obtain the name or string representation of a graph element.
If the graph element has the attribute "name", return name. Otherwise, return
a __str__ representation of the graph element. Certain graph elements, such as
`SparseTensor`s, do not have the attribute "name".
Args:
elem: The graph element in question.
Returns:
If the attribute 'name' is available, return the name. Otherwise, return
str(elem).
"""
return elem.name if hasattr(elem, "name") else str(elem)
def get_flattened_names(feeds_or_fetches):
"""Get a flattened list of the names in run() call feeds or fetches.
Args:
feeds_or_fetches: Feeds or fetches of the `Session.run()` call. It may be
a Tensor, an Operation or a Variable. It may also be nested lists, tuples
or dicts. See doc of `Session.run()` for more details.
Returns:
(list of str) A flattened list of fetch names from `feeds_or_fetches`.
"""
lines = []
if isinstance(feeds_or_fetches, (list, tuple)):
for item in feeds_or_fetches:
lines.extend(get_flattened_names(item))
elif isinstance(feeds_or_fetches, dict):
for key in feeds_or_fetches:
lines.extend(get_flattened_names(feeds_or_fetches[key]))
else:
# This ought to be a Tensor, an Operation or a Variable, for which the name
# attribute should be available. (Bottom-out condition of the recursion.)
lines.append(get_graph_element_name(feeds_or_fetches))
return lines
def get_run_key(feed_dict, fetches):
"""Summarize the names of feeds and fetches as a RunKey JSON string.
Args:
feed_dict: The feed_dict given to the `Session.run()` call.
fetches: The fetches from the `Session.run()` call.
Returns:
A JSON Array consisting of two items. The first item is a flattened
Array of the names of the feeds. The second item is a flattened Array of
the names of the fetches.
"""
return json.dumps(RunKey(get_flattened_names(feed_dict),
get_flattened_names(fetches)))
View File
@ -0,0 +1,59 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for common values and methods of TensorFlow Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from tensorflow.python.debug.lib import common
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class CommonTest(test_util.TensorFlowTestCase):
def testOnFeedOneFetch(self):
a = constant_op.constant(10.0, name="a")
b = constant_op.constant(20.0, name="b")
run_key = common.get_run_key({"a": a}, [b])
loaded = json.loads(run_key)
self.assertItemsEqual(["a:0"], loaded[0])
self.assertItemsEqual(["b:0"], loaded[1])
def testGetRunKeyFlat(self):
a = constant_op.constant(10.0, name="a")
b = constant_op.constant(20.0, name="b")
run_key = common.get_run_key({"a": a}, [a, b])
loaded = json.loads(run_key)
self.assertItemsEqual(["a:0"], loaded[0])
self.assertItemsEqual(["a:0", "b:0"], loaded[1])
def testGetRunKeyNestedFetches(self):
a = constant_op.constant(10.0, name="a")
b = constant_op.constant(20.0, name="b")
c = constant_op.constant(30.0, name="c")
d = constant_op.constant(30.0, name="d")
run_key = common.get_run_key(
{}, {"set1": [a, b], "set2": {"c": c, "d": d}})
loaded = json.loads(run_key)
self.assertItemsEqual([], loaded[0])
self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], loaded[1])
if __name__ == "__main__":
googletest.main()
View File
@ -310,7 +310,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
op_log_proto.id_to_string) op_log_proto.id_to_string)
raise ValueError( raise ValueError(
"Op '%s' does not exist in the tracebacks received by the debug " "Op '%s' does not exist in the tracebacks received by the debug "
"server.") "server." % op_name)
def query_origin_stack(self): def query_origin_stack(self):
"""Query the stack of the origin of the execution call. """Query the stack of the origin of the execution call.
@ -348,6 +348,9 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
Raises: Raises:
ValueError: If no source file is found at the given file_path. ValueError: If no source file is found at the given file_path.
""" """
if not self._source_files:
raise ValueError(
"This debug server has not received any source file contents yet.")
for source_file_proto in self._source_files.source_files: for source_file_proto in self._source_files.source_files:
if source_file_proto.file_path == file_path: if source_file_proto.file_path == file_path:
return source_file_proto.lines[lineno - 1] return source_file_proto.lines[lineno - 1]
View File
@ -248,7 +248,7 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
self.assertEqual( self.assertEqual(
14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0])) 14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
def testTensorBoardDebugHooWorks(self): def testTensorBoardDebugHookWorks(self):
u = variables.Variable(2.1, name="u") u = variables.Variable(2.1, name="u")
v = variables.Variable(20.0, name="v") v = variables.Variable(20.0, name="v")
w = math_ops.multiply(u, v, name="w") w = math_ops.multiply(u, v, name="w")
@ -261,8 +261,37 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
["localhost:%d" % self._server_port]) ["localhost:%d" % self._server_port])
sess = monitored_session._HookedSession(sess, [grpc_debug_hook]) sess = monitored_session._HookedSession(sess, [grpc_debug_hook])
# Activate a watch point on a tensor before calling sess.run().
self._server.request_watch("u/read", 0, "DebugIdentity")
self.assertAllClose(42.0, sess.run(w)) self.assertAllClose(42.0, sess.run(w))
dump = debug_data.DebugDumpDir(self._dump_root)
self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
# Check that the server has received the stack trace.
self.assertTrue(self._server.query_op_traceback("u"))
self.assertTrue(self._server.query_op_traceback("u/read"))
self.assertTrue(self._server.query_op_traceback("v"))
self.assertTrue(self._server.query_op_traceback("v/read"))
self.assertTrue(self._server.query_op_traceback("w"))
# Check that the server has received the python file content.
# Query an arbitrary line to make sure that is the case.
with open(__file__, "rt") as this_source_file:
first_line = this_source_file.readline().strip()
self.assertEqual(
first_line, self._server.query_source_file_line(__file__, 1))
self._server.clear_data()
# Call sess.run() again, and verify that this time the traceback and source
# code are not sent, because the graph version is not newer.
self.assertAllClose(42.0, sess.run(w))
with self.assertRaises(ValueError):
self._server.query_op_traceback("delta_1")
with self.assertRaises(ValueError):
self._server.query_source_file_line(__file__, 1)
def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self): def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
hooks.GrpcDebugHook(["grpc://foo:42424"]) hooks.GrpcDebugHook(["grpc://foo:42424"])
hooks.GrpcDebugHook(["foo:42424"]) hooks.GrpcDebugHook(["foo:42424"])
@ -748,6 +777,28 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
# to disable the breakpoint at delta:0:DebugIdentity. # to disable the breakpoint at delta:0:DebugIdentity.
self.assertSetEqual(set(), self._server_1.breakpoints) self.assertSetEqual(set(), self._server_1.breakpoints)
if i == 0:
# Check that the server has received the stack trace.
self.assertTrue(self._server_1.query_op_traceback("delta_1"))
self.assertTrue(self._server_1.query_op_traceback("delta_2"))
self.assertTrue(self._server_1.query_op_traceback("inc_v_1"))
self.assertTrue(self._server_1.query_op_traceback("inc_v_2"))
# Check that the server has received the python file content.
# Query an arbitrary line to make sure that is the case.
with open(__file__, "rt") as this_source_file:
first_line = this_source_file.readline().strip()
self.assertEqual(
first_line, self._server_1.query_source_file_line(__file__, 1))
else:
# In later Session.run() calls, the traceback shouldn't have been sent
# because it was already sent in the 1st call. So calling
# query_op_traceback() should lead to an exception, because the test
# debug server clears the data at the beginning of every iteration.
with self.assertRaises(ValueError):
self._server_1.query_op_traceback("delta_1")
with self.assertRaises(ValueError):
self._server_1.query_source_file_line(__file__, 1)
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self): def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
with session.Session() as sess: with session.Session() as sess:
v = variables.Variable(50.0, name="v") v = variables.Variable(50.0, name="v")
View File
@ -24,6 +24,7 @@ import grpc
from tensorflow.core.debug import debug_service_pb2 from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.protobuf import debug_pb2 from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_service_pb2_grpc from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.debug.lib import source_utils from tensorflow.python.debug.lib import source_utils
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
@ -130,6 +131,11 @@ def _send_call_tracebacks(destinations,
""" """
if not isinstance(destinations, list): if not isinstance(destinations, list):
destinations = [destinations] destinations = [destinations]
# Strip grpc:// prefix, if any is present.
destinations = [
dest[len(common.GRPC_URL_PREFIX):]
if dest.startswith(common.GRPC_URL_PREFIX) else dest
for dest in destinations]
call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
if is_eager_execution if is_eager_execution
View File
@ -17,15 +17,55 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import sys
import traceback
# Google-internal import(s). # Google-internal import(s).
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.wrappers import framework from tensorflow.python.debug.wrappers import framework
def publish_traceback(debug_server_urls,
graph,
feed_dict,
fetches,
old_graph_version):
"""Publish traceback and source code if graph version is new.
`graph.version` is compared with `old_graph_version`. If the former is higher
(i.e., newer), the graph traceback and the associated source code is sent to
the debug server at the specified gRPC URLs.
Args:
debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of
debug server URLs.
graph: A Python `tf.Graph` object.
feed_dict: Feed dictionary given to the `Session.run()` call.
fetches: Fetches from the `Session.run()` call.
old_graph_version: Old graph version to compare to.
Returns:
If `graph.version > old_graph_version`, the new graph version as an `int`.
Else, the `old_graph_version` is returned.
"""
# TODO(cais): Consider moving this back to the top, after grpc becomes a
# pip dependency of tensorflow or tf_debug.
# pylint:disable=g-import-not-at-top
from tensorflow.python.debug.lib import source_remote
# pylint:enable=g-import-not-at-top
if graph.version > old_graph_version:
run_key = common.get_run_key(feed_dict, fetches)
source_remote.send_graph_tracebacks(
debug_server_urls, run_key, traceback.extract_stack(), graph,
send_source=True)
return graph.version
else:
return old_graph_version
class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
  """Debug Session wrapper that sends debug data to gRPC stream(s)."""

  _GRPC_URL_PREFIX = "grpc://"
  def __init__(self,
               sess,
               grpc_debug_server_addresses,
@@ -94,8 +134,8 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
    return self._grpc_debug_server_urls

  def _normalize_grpc_url(self, address):
    return (common.GRPC_URL_PREFIX + address
            if not address.startswith(common.GRPC_URL_PREFIX) else address)
class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
@@ -126,3 +166,25 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)
    # Keeps track of the latest version of the Python graph object that has
    # been sent to the debug servers.
    self._sent_graph_version = -sys.maxint
  def run(self,
          fetches,
          feed_dict=None,
          options=None,
          run_metadata=None,
          callable_runner=None,
          callable_runner_args=None):
    self._sent_graph_version = publish_traceback(
        self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
        self._sent_graph_version)
    return super(TensorBoardDebugWrapperSession, self).run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata,
        callable_runner=callable_runner,
        callable_runner_args=callable_runner_args)
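
A hypothetical usage sketch of the wrapper, assuming the class is exported via tensorflow.python.debug as in released versions; the address is illustrative, and actually running the wrapped session needs a TensorBoard debugger backend listening on that port:

import tensorflow as tf
from tensorflow.python import debug as tf_debug

v = tf.Variable(50.0, name="v")
sess = tf.Session()
sess.run(v.initializer)
# Each run() on the wrapper first publishes graph tracebacks (only when the
# graph is newer than the last published version), then delegates to run().
sess = tf_debug.TensorBoardDebugWrapperSession(sess, "localhost:6064")
# sess.run(v)  # would publish tracebacks, then fetch v as usual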

View File

@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import stepper
@@ -331,3 +333,13 @@ class TensorBoardDebugHook(GrpcDebugHook):
        watch_fn=_gated_grpc_watch_fn,
        thread_name_filter=thread_name_filter,
        log_usage=log_usage)
    self._grpc_debug_server_addresses = grpc_debug_server_addresses
    self._sent_graph_version = -sys.maxint

  def before_run(self, run_context):
    self._sent_graph_version = grpc_wrapper.publish_traceback(
        self._grpc_debug_server_addresses, run_context.session.graph,
        run_context.original_args.feed_dict, run_context.original_args.fetches,
        self._sent_graph_version)
    return super(TensorBoardDebugHook, self).before_run(run_context)
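
The hook variant covers estimator/MonitoredSession workflows. A hypothetical wiring, assuming the hook is exported via tensorflow.python.debug (address illustrative; constructing the hook does not connect, but a training run would need a live backend):

from tensorflow.python import debug as tf_debug

hook = tf_debug.TensorBoardDebugHook("localhost:6064")
# estimator.train(input_fn=input_fn, hooks=[hook])
# Tracebacks are published from before_run() on each call.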

View File

@@ -31,6 +31,7 @@ from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
@@ -464,7 +465,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
    feed_key = None
    feed_value = None
    for key in self._feed_dict:
      key_name = common.get_graph_element_name(key)
      if key_name == tensor_name:
        feed_key = key_name
        feed_value = self._feed_dict[key]
@@ -561,7 +562,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
        list(self._tensor_filters.keys()))

    if self._feed_dict:
      # Register tab completion for feed_dict keys.
      feed_keys = [common.get_graph_element_name(key)
                   for key in self._feed_dict.keys()]
      curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)

View File

@@ -495,14 +495,13 @@ class GraphModeFunction(object):
def _get_defun_inputs(args):
  """Maps the inputs args to graph inputs."""
  ret = []
  flat_args = nest.flatten(args)
  for a in flat_args:
    if isinstance(a, ops.Tensor):
      ret.append(graph_placeholder(a.dtype, a.shape))
    else:
      ret.append(a)
  return nest.pack_sequence_as(args, ret)
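
This rewrite leans on nest.flatten / nest.pack_sequence_as instead of hand-rolled recursion, which is what lets defun accept dicts and namedtuples. A self-contained sketch of that round trip under simplified assumptions (lists/tuples only; doubling stands in for placeholder creation):

def flatten(structure):
  # Minimal stand-in for nest.flatten: left-to-right leaves of the structure.
  if isinstance(structure, (list, tuple)):
    return [leaf for s in structure for leaf in flatten(s)]
  return [structure]

def pack_sequence_as(structure, flat):
  # Minimal stand-in for nest.pack_sequence_as: rebuild the same shape.
  it = iter(flat)
  def pack(s):
    if isinstance(s, (list, tuple)):
      return type(s)(pack(e) for e in s)
    return next(it)
  return pack(structure)

args = (1, [2, (3, 4)])
flat = flatten(args)              # [1, 2, 3, 4]
doubled = [2 * v for v in flat]   # per-leaf transform, like graph_placeholder
assert pack_sequence_as(args, doubled) == (2, [4, (6, 8)])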
def _defun_internal(name, func, args, kwds):
@@ -582,8 +581,10 @@ def _cache_key(x):
    return _TensorDtype(x.dtype, x._shape_tuple())  # pylint: disable=protected-access
  if isinstance(x, np.ndarray):
    return ("array", x.shape, tuple(x.reshape(-1)))
  if isinstance(x, (list, tuple)):
    return tuple([_cache_key(a) for a in x])
  if isinstance(x, dict):
    return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items())
  return x
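
A stripped-down sketch of the cache-key recursion above (plain Python; the tensor and ndarray branches are omitted) shows why equal nested structures now map to the same function-cache entry:

def cache_key(x):
  if isinstance(x, (list, tuple)):
    return tuple(cache_key(a) for a in x)
  if isinstance(x, dict):
    # New in this change: dicts contribute their (key, value) pairs.
    return tuple((cache_key(k), cache_key(v)) for k, v in x.items())
  return x

assert cache_key({'a': [1, 2]}) == cache_key({'a': [1, 2]})
assert cache_key((1, {'b': 2})) != cache_key((1, {'b': 3}))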

View File

@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
@@ -57,6 +59,20 @@ class FunctionTest(test.TestCase):
    out = sq(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
  def testNestedInputsGraphMode(self):
    matmul = function.defun(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @function.defun
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    out = a_times_b(pair({'a': t}, {'b': t}))
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
  def testGraphModeWithGradients(self):
    v = resource_variable_ops.ResourceVariable(1.0, name='v')
@@ -83,6 +99,22 @@ class FunctionTest(test.TestCase):
    out = sq_op(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
  def testNestedInputsDefunOpGraphMode(self):
    matmul = function.defun(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    inputs = pair({'a': t}, {'b': t})

    sq_op = function.make_defun_op(a_times_b, inputs)

    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(inputs)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
  def testNestedOutputDefunOpGraphMode(self):
    matmul = function.defun(math_ops.matmul)

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include "tensorflow/python/lib/core/bfloat16.h" #include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/numeric_types.h"
@@ -477,8 +479,61 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
  return true;
}
template <typename InType, typename OutType, typename Functor>
void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
                 void* data) {
  const char* i0 = args[0];
  const char* i1 = args[1];
  char* o = args[2];
  for (npy_intp k = 0; k < *dimensions; k++) {
    InType x = *reinterpret_cast<const InType*>(i0);
    InType y = *reinterpret_cast<const InType*>(i1);
    *reinterpret_cast<OutType*>(o) = Functor()(x, y);
    i0 += steps[0];
    i1 += steps[1];
    o += steps[2];
  }
}

template <typename Functor>
void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
                  void* data) {
  BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
}

struct Bfloat16EqFunctor {
  npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
};
struct Bfloat16NeFunctor {
  npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
};
struct Bfloat16LtFunctor {
  npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
};
struct Bfloat16GtFunctor {
  npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
};
struct Bfloat16LeFunctor {
  npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
};
struct Bfloat16GeFunctor {
  npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
};
// Initializes the module.
bool Initialize() {
  // It's critical to import umath to avoid crash in open source build.
  import_umath1(false);

  Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
  if (!numpy_str) {
    return false;
  }
  Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
  if (!numpy) {
    return false;
  }
  // We hit a mysterious crash if we haven't initialized numpy before this:
  PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
@@ -536,6 +591,57 @@ bool Initialize() {
                            /*cast_is_safe=*/true)) {
    return false;
  }
  // Register ufuncs
  auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn,
                            const std::array<int, 3>& types) {
    Safe_PyObjectPtr ufunc_obj =
        make_safe(PyObject_GetAttrString(numpy.get(), name));
    if (!ufunc_obj) {
      return false;
    }
    PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
    if (types.size() != ufunc->nargs) {
      PyErr_Format(PyExc_AssertionError,
                   "ufunc %s takes %d arguments, loop takes %lu", name,
                   ufunc->nargs, types.size());
      return false;
    }
    if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn,
                                    const_cast<int*>(types.data()),
                                    nullptr) < 0) {
      return false;
    }
    return true;
  };

  // Comparisons
  const std::array<int, 3> compare_types = {npy_bfloat16_, npy_bfloat16_,
                                            NPY_BOOL};

  if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>,
                      compare_types)) {
    return false;
  }
  if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>,
                      compare_types)) {
    return false;
  }
  if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) {
    return false;
  }
  if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>,
                      compare_types)) {
    return false;
  }
  if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>,
                      compare_types)) {
    return false;
  }
  if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>,
                      compare_types)) {
    return false;
  }
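
Once these loops are registered, NumPy's comparison ufuncs dispatch to them for the bfloat16 dtype. For a rough feel of the behavior, here is a sketch using float16 as a stand-in dtype so the snippet runs on stock NumPy:

import numpy as np

x = np.array([1, 7, -32], dtype=np.float16)
y = np.array([2, 7, 0], dtype=np.float16)
# With the loops above registered, the same calls work on bfloat16 arrays.
print(np.equal(x, y))  # [False  True False]
print(np.less(x, y))   # [ True False  True]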
  return true;
}

View File

@@ -172,6 +172,24 @@ class Bfloat16NumPyTest(test.TestCase):
self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x)) self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x))
self.assertAllEqual(x, x) self.assertAllEqual(x, x)
self.assertAllClose(x, x) self.assertAllClose(x, x)
    self.assertTrue((x == x).all())

  def testComparisons(self):
    x = np.array([401408, 7, -32], dtype=np.float32)
    bx = x.astype(bfloat16)
    y = np.array([82432, 7, 0], dtype=np.float32)
    by = y.astype(bfloat16)
    self.assertAllEqual(x == y, bx == by)
    self.assertAllEqual(x != y, bx != by)
    self.assertAllEqual(x < y, bx < by)
    self.assertAllEqual(x > y, bx > by)
    self.assertAllEqual(x <= y, bx <= by)
    self.assertAllEqual(x >= y, bx >= by)

  def testEqual2(self):
    a = np.array([401408], bfloat16)
    b = np.array([82432], bfloat16)
    self.assertFalse(a.__eq__(b))
  def testCasts(self):
    for dtype in [

View File

@@ -32,6 +32,7 @@ limitations under the License.
#include <Python.h>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
namespace tensorflow {

View File

@@ -25,6 +25,7 @@ py_library(
":main_op", ":main_op",
":signature_constants", ":signature_constants",
":signature_def_utils", ":signature_def_utils",
":simple_save",
":tag_constants", ":tag_constants",
":utils", ":utils",
"//tensorflow/python:util", "//tensorflow/python:util",
@@ -89,6 +90,23 @@ py_library(
    ],
)
py_library(
    name = "simple_save",
    srcs = [
        "simple_save.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":builder",
        ":signature_constants",
        ":signature_def_utils",
        ":tag_constants",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:lib",
        "//tensorflow/python:util",
    ],
)
py_library(
    name = "main_op",
    srcs = [
@@ -198,6 +216,22 @@ py_test(
    ],
)
py_test(
    name = "simple_save_test",
    size = "small",
    srcs = ["simple_save_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":loader",
        ":signature_constants",
        ":simple_save",
        ":tag_constants",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:variables",
    ],
)
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

View File

@@ -30,6 +30,9 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
# pylint: enable=unused-import
# pylint: disable=wildcard-import
from tensorflow.python.saved_model.simple_save import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -41,6 +44,7 @@ _allowed_symbols = [
"main_op", "main_op",
"signature_constants", "signature_constants",
"signature_def_utils", "signature_def_utils",
"simple_save",
"tag_constants", "tag_constants",
"utils", "utils",
] ]

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel utility functions.""" """SavedModel simple save functionality."""
from __future__ import absolute_import
from __future__ import division
@@ -39,7 +39,7 @@ def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None):
  to configure a SavedModel, this method has a few practical implications:
    - It will be treated as a graph for inference / serving (i.e. uses the tag
      `tag_constants.SERVING`)
    - The SavedModel will load in TensorFlow Serving and supports the
      [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto).
      To use the Classify, Regress, or MultiInference APIs, please
      use either
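
To round out the picture, a hypothetical end-to-end use of the new entry point, matching the `simple_save(session, export_dir, inputs, outputs, legacy_init_op=None)` signature above (the export path and tensor names are illustrative):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
w = tf.Variable(tf.ones([3, 1]), name="w")
y = tf.matmul(x, w, name="y")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # One call sets up the SERVING tag and a Predict signature_def.
  tf.saved_model.simple_save(
      sess, "/tmp/simple_model", inputs={"x": x}, outputs={"y": y})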

Some files were not shown because too many files have changed in this diff.