Merge commit for internal changes
This commit is contained in:
commit
15283d90e7
@ -42,6 +42,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
@ -73,6 +74,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/framework/scope_internal.h"
|
||||
#include "tensorflow/cc/ops/while_loop.h"
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#endif
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
@ -2618,4 +2619,54 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
|
||||
output_values, target_names, nullptr, status);
|
||||
}
|
||||
|
||||
TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
|
||||
tensorflow::OpList op_list;
|
||||
if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
|
||||
status->status = InvalidArgument("Unparseable OpList");
|
||||
return nullptr;
|
||||
}
|
||||
status->status = Status::OK();
|
||||
return new TF_ApiDefMap(op_list);
|
||||
}
|
||||
|
||||
void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
|
||||
|
||||
void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
|
||||
size_t text_len, TF_Status* status) {
|
||||
#ifdef __ANDROID__
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"ApiDefMap is not supported in Android.");
|
||||
#else
|
||||
mutex_lock l(api_def_map->lock);
|
||||
if (api_def_map->update_docs_called) {
|
||||
status->status = FailedPrecondition(
|
||||
"TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
|
||||
"called.");
|
||||
return;
|
||||
}
|
||||
string api_def_text(text, text_len);
|
||||
status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
|
||||
#endif // __ANDROID__
|
||||
}
|
||||
|
||||
TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
|
||||
size_t name_len, TF_Status* status) {
|
||||
#ifdef __ANDROID__
|
||||
status->status = tensorflow::errors::Unimplemented(
|
||||
"ApiDefMap is not supported in Android.");
|
||||
return nullptr;
|
||||
#else
|
||||
mutex_lock l(api_def_map->lock);
|
||||
if (!api_def_map->update_docs_called) {
|
||||
api_def_map->api_def_map.UpdateDocs();
|
||||
api_def_map->update_docs_called = true;
|
||||
}
|
||||
string name_str(name, name_len);
|
||||
const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
|
||||
|
||||
TF_Buffer* ret = TF_NewBuffer();
|
||||
status->status = MessageToBuffer(*api_def, ret);
|
||||
return ret;
|
||||
#endif // __ANDROID__
|
||||
}
|
||||
} // end extern "C"
|
||||
|
@ -1518,6 +1518,49 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
|
||||
// in this address space.
|
||||
TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList();
|
||||
|
||||
// TF_ApiDefMap encapsulates a collection of API definitions for an operation.
|
||||
//
|
||||
// This object maps the name of a TensorFlow operation to a description of the
|
||||
// API to generate for it, as defined by the ApiDef protocol buffer (
|
||||
// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto)
|
||||
//
|
||||
// The ApiDef messages are typically used to generate convenience wrapper
|
||||
// functions for TensorFlow operations in various language bindings.
|
||||
typedef struct TF_ApiDefMap TF_ApiDefMap;
|
||||
|
||||
// Creates a new TF_ApiDefMap instance.
|
||||
//
|
||||
// Params:
|
||||
// op_list_buffer - TF_Buffer instance containing serialized OpList
|
||||
// protocol buffer. (See
|
||||
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto
|
||||
// for the OpList proto definition).
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer,
|
||||
TF_Status* status);
|
||||
|
||||
// Deallocates a TF_ApiDefMap.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap);
|
||||
|
||||
// Add ApiDefs to the map.
|
||||
//
|
||||
// `text` corresponds to a text representation of an ApiDefs protocol message.
|
||||
// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto).
|
||||
//
|
||||
// The provided ApiDefs will be merged with existing ones in the map, with
|
||||
// precedence given to the newly added version in case of conflicts with
|
||||
// previous calls to TF_ApiDefMapPut.
|
||||
TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map,
|
||||
const char* text, size_t text_len,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns a serialized ApiDef protocol buffer for the TensorFlow operation
|
||||
// named `name`.
|
||||
TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map,
|
||||
const char* name,
|
||||
size_t name_len,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -24,6 +24,9 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#ifndef __ANDROID__
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
@ -158,6 +161,22 @@ struct TF_Function {
|
||||
tensorflow::FunctionDef fdef;
|
||||
};
|
||||
|
||||
struct TF_ApiDefMap {
|
||||
explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
|
||||
:
|
||||
#ifndef __ANDROID__
|
||||
api_def_map(op_list),
|
||||
#endif
|
||||
update_docs_called(false) {
|
||||
}
|
||||
|
||||
#ifndef __ANDROID__
|
||||
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
|
||||
#endif
|
||||
bool update_docs_called GUARDED_BY(lock);
|
||||
tensorflow::mutex lock;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/saved_model/tag_constants.h"
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
#include "tensorflow/core/example/feature.pb.h"
|
||||
#include "tensorflow/core/framework/api_def.pb.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
||||
#include "tensorflow/core/framework/node_def.pb_text.h"
|
||||
@ -2027,6 +2028,77 @@ TEST_F(CApiAttributesTest, Errors) {
|
||||
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
|
||||
}
|
||||
|
||||
TEST(TestApiDef, TestCreateApiDef) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op.so", status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TF_Buffer op_list_buf = TF_GetOpList(lib);
|
||||
status = TF_NewStatus();
|
||||
auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
string op_name = "TestCApi";
|
||||
status = TF_NewStatus();
|
||||
auto* api_def_buf =
|
||||
TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
tensorflow::ApiDef api_def;
|
||||
EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
|
||||
EXPECT_EQ(op_name, api_def.graph_op_name());
|
||||
EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
|
||||
|
||||
TF_DeleteBuffer(api_def_buf);
|
||||
TF_DeleteApiDefMap(api_def_map);
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
|
||||
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Library* lib =
|
||||
TF_LoadLibrary("tensorflow/c/test_op.so", status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TF_Buffer op_list_buf = TF_GetOpList(lib);
|
||||
status = TF_NewStatus();
|
||||
auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
string api_def_overwrites = R"(op: <
|
||||
graph_op_name: "TestCApi"
|
||||
summary: "New summary"
|
||||
>
|
||||
)";
|
||||
status = TF_NewStatus();
|
||||
TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(),
|
||||
api_def_overwrites.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
string op_name = "TestCApi";
|
||||
status = TF_NewStatus();
|
||||
auto* api_def_buf =
|
||||
TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
tensorflow::ApiDef api_def;
|
||||
EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
|
||||
EXPECT_EQ(op_name, api_def.graph_op_name());
|
||||
EXPECT_EQ("New summary", api_def.summary());
|
||||
|
||||
TF_DeleteBuffer(api_def_buf);
|
||||
TF_DeleteApiDefMap(api_def_map);
|
||||
TF_DeleteLibraryHandle(lib);
|
||||
}
|
||||
|
||||
#undef EXPECT_TF_META
|
||||
|
||||
} // namespace
|
||||
|
@ -33,7 +33,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}),
|
||||
}) + ["//tensorflow/core:gpu_runtime"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
@ -55,6 +55,10 @@ tf_cuda_library(
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_test",
|
||||
srcs = ["c_api_test.cc"],
|
||||
tags = [
|
||||
"guitar",
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/runtime.h"
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -167,18 +168,6 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
if (is_same_device) {
|
||||
return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
|
||||
}
|
||||
const bool src_cpu = IsCPU(srcd);
|
||||
if (src_cpu == dst_cpu) {
|
||||
TF_SetStatus(
|
||||
status, TF_INVALID_ARGUMENT,
|
||||
tensorflow::strings::StrCat(
|
||||
"TFE_TensorHandleCopyToDevice requires either the source "
|
||||
"TFE_TensorHandle be on or the destination device be on CPU "
|
||||
"or be the same (they are ",
|
||||
DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Tensor* src = &(h->t);
|
||||
if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) {
|
||||
TF_SetStatus(
|
||||
@ -189,26 +178,19 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
if (src_cpu) {
|
||||
tensorflow::Tensor dst(
|
||||
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
|
||||
src->shape());
|
||||
if (src->shape().num_elements() == 0) {
|
||||
return new TFE_TensorHandle(dst, dstd);
|
||||
}
|
||||
tensorflow::Notification n;
|
||||
dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice(
|
||||
src, dstd, &dst, [status, &n](const tensorflow::Status& s) {
|
||||
status->status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd)
|
||||
: nullptr;
|
||||
tensorflow::Tensor dst(dstd->GetAllocator(tensorflow::AllocatorAttributes()),
|
||||
src->dtype(), src->shape());
|
||||
if (src->shape().num_elements() == 0) {
|
||||
return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
|
||||
}
|
||||
tensorflow::DeviceContext* src_device_context = nullptr;
|
||||
if (!IsCPU(srcd)) {
|
||||
src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
|
||||
}
|
||||
tensorflow::DeviceContext* dst_device_context = nullptr;
|
||||
if (!dst_cpu) {
|
||||
dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
|
||||
}
|
||||
CHECK(dst_cpu);
|
||||
tensorflow::Tensor dst(src->dtype(), src->shape());
|
||||
tensorflow::Notification n;
|
||||
// TODO(ashankar): The Sync() call below may be more aggressive than
|
||||
// necessary. It is based on knowledge of implementation details - that
|
||||
// GPU devices are implemented using 3 streams - one for host->device copies,
|
||||
@ -217,16 +199,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
// but more than necessary (since it waits for operations that might have
|
||||
// nothing to do with this tensor to complete).
|
||||
status->status = srcd->Sync();
|
||||
if (!status->status.ok()) return nullptr;
|
||||
srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU(
|
||||
src, "IGNORE_MY_TENSOR_NAME", srcd, &dst,
|
||||
[status, &n](const tensorflow::Status& s) {
|
||||
status->status = s;
|
||||
n.Notify();
|
||||
});
|
||||
tensorflow::Notification n;
|
||||
tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
|
||||
srcd, dstd, tensorflow::AllocatorAttributes(),
|
||||
tensorflow::AllocatorAttributes(), src, &dst,
|
||||
[status, &n](const tensorflow::Status& s) {
|
||||
status->status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr)
|
||||
: nullptr;
|
||||
return (TF_GetCode(status) == TF_OK)
|
||||
? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
|
@ -216,6 +216,64 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
const int num_devices = TF_DeviceListCount(devices);
|
||||
|
||||
const char* kCPUDevice = "CPU:0";
|
||||
if (num_devices < 3) {
|
||||
TF_DeleteDeviceList(devices);
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_DeleteContext(ctx, status.get());
|
||||
return;
|
||||
}
|
||||
const string gpu_1_name(TF_DeviceListName(devices, 1, status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
|
||||
const string gpu_2_name(TF_DeviceListName(devices, 2, status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
|
||||
|
||||
TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice(
|
||||
hdevice, ctx, gpu_2_name.c_str(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
|
||||
TFE_DeleteTensorHandle(hdevice);
|
||||
// Copy back to CPU
|
||||
TFE_TensorHandle* hcopy =
|
||||
TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
|
||||
TFE_DeleteTensorHandle(hdevice2);
|
||||
|
||||
// Ensure that the contents are the same!
|
||||
TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
|
||||
TFE_DeleteTensorHandle(hcopy);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
|
||||
EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy));
|
||||
EXPECT_EQ(
|
||||
0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)));
|
||||
TF_DeleteTensor(tcopy);
|
||||
|
||||
TF_DeleteDeviceList(devices);
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_DeleteContext(ctx, status.get());
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleSilentCopy) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
@ -248,6 +248,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:const_analysis",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -48,6 +48,16 @@ typedef std::function<Status(
|
||||
// 'group_attribute' must be a string valued-attribute that names the new
|
||||
// functions to introduce.
|
||||
//
|
||||
// 'outside_compilation_attribute' must be a string-valued attribute that is
|
||||
// used to tag nodes within a subgraph to be part of an 'outside_compilation'
|
||||
// cluster within the subgraph. A cluster is formed from the set of nodes with
|
||||
// the same value of outside_compilation_subgraph and group_attribute. The nodes
|
||||
// in an outside_compilation cluster are left in the original graph. Edges
|
||||
// crossing from the subgraph to an outside_compilation cluster nested in the
|
||||
// subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
|
||||
// crossing from an outside_compilation cluster into its enclosing subgraph are
|
||||
// lifted into a SendFromHost/RecvFromHost pair of nodes.
|
||||
//
|
||||
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
|
||||
// function conversion.
|
||||
//
|
||||
@ -64,10 +74,10 @@ typedef std::function<Status(
|
||||
// dep from B. Originally D must run after C, post-transformation this
|
||||
// dependency is lost.
|
||||
Status EncapsulateSubgraphsInFunctions(
|
||||
string group_attribute, const Graph& graph_in,
|
||||
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
|
||||
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
|
||||
FunctionLibraryDefinition* library);
|
||||
string group_attribute, string outside_compilation_attribute,
|
||||
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
|
||||
bool parallel_checking, bool reuse_existing_functions,
|
||||
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
|
||||
|
||||
// The attribute that marks function calls produced by the encapsulate
|
||||
// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
|
||||
|
@ -36,7 +36,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
|
||||
if (diff) {
|
||||
*diff = strings::StrCat("Definition mismatch for function ",
|
||||
a.signature().name(), ", expected:\n",
|
||||
a.DebugString());
|
||||
a.DebugString(), "\ngot:\n", b.DebugString());
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -82,6 +82,24 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
|
||||
<< diff << "\nActual: " << actual.DebugString(); \
|
||||
} while (false)
|
||||
|
||||
// TODO(misard): remove these fake registrations once there are real Ops to be
|
||||
// compiled.
|
||||
REGISTER_OP("_XlaSendToHost")
|
||||
.Input("input: dtypes")
|
||||
.Attr("dtypes: list(type) >= 0");
|
||||
|
||||
REGISTER_OP("_XlaRecvFromHost")
|
||||
.Output("output: dtypes")
|
||||
.Attr("dtypes: list(type) >= 0");
|
||||
|
||||
REGISTER_OP("_XlaSendFromHost")
|
||||
.Input("input: dtypes")
|
||||
.Attr("dtypes: list(type) >= 0");
|
||||
|
||||
REGISTER_OP("_XlaRecvAtHost")
|
||||
.Output("output: dtypes")
|
||||
.Attr("dtypes: list(type) >= 0");
|
||||
|
||||
REGISTER_OP("InputTest").Output("o: float");
|
||||
|
||||
REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
|
||||
@ -98,10 +116,32 @@ REGISTER_OP("AddNLikeTest")
|
||||
.SetIsCommutative()
|
||||
.SetIsAggregate();
|
||||
|
||||
Node* NoOp(const GraphDefBuilder::Options& opts) {
|
||||
return ops::SourceOp("NoOp", opts);
|
||||
}
|
||||
|
||||
Node* Input(const GraphDefBuilder::Options& opts) {
|
||||
return ops::SourceOp("InputTest", opts);
|
||||
}
|
||||
|
||||
Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
|
||||
const GraphDefBuilder::Options& opts) {
|
||||
if (opts.HaveError()) return nullptr;
|
||||
NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
|
||||
"_XlaRecvAtHost", opts.op_registry());
|
||||
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
|
||||
}
|
||||
|
||||
Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
|
||||
const gtl::ArraySlice<DataType>& dtypes,
|
||||
const GraphDefBuilder::Options& opts) {
|
||||
if (opts.HaveError()) return nullptr;
|
||||
NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
|
||||
"_XlaSendFromHost", opts.op_registry());
|
||||
node_builder.Input(inputs);
|
||||
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
|
||||
}
|
||||
|
||||
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
|
||||
return ops::UnaryOp("UnaryTest", std::move(a), opts);
|
||||
}
|
||||
@ -145,7 +185,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
|
||||
if (!s.ok()) return s;
|
||||
|
||||
std::unique_ptr<Graph> graph_out;
|
||||
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
|
||||
s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph,
|
||||
/*rewrite_subgraph_fn=*/{},
|
||||
/*parallel_checking=*/false,
|
||||
/*reuse_existing_functions=*/false,
|
||||
@ -178,6 +218,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) {
|
||||
FunctionDefLibrary library_out = library_in;
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
|
||||
|
||||
// If there are no marked nodes, funcification should be a no-op.
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
|
||||
}
|
||||
@ -230,7 +271,6 @@ TEST(EncapsulateSubgraphsTest, OneFunction) {
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
// If there are no marked nodes, funcification should be a no-op.
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
@ -342,9 +382,9 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
|
||||
/*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph,
|
||||
&library));
|
||||
"_cluster", "_outside", graph_before_encapsulation,
|
||||
/*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false,
|
||||
/*reuse_existing_functions=*/false, &graph, &library));
|
||||
|
||||
std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
|
||||
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
|
||||
@ -374,9 +414,9 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) {
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
|
||||
/*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph,
|
||||
&library));
|
||||
"_cluster", "_outside", graph_before_encapsulation,
|
||||
/*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true,
|
||||
/*reuse_existing_functions=*/false, &graph, &library));
|
||||
|
||||
std::vector<string> expected_nodes = {
|
||||
"add1", "add2", "cluster1", "cluster1_parallel_check/_0",
|
||||
@ -432,7 +472,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
int guaranteed_consts = 0;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_encapsulate", graph_before,
|
||||
"_encapsulate", "_outside", graph_before,
|
||||
/*rewrite_subgraph_fn=*/
|
||||
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
|
||||
std::vector<int>* input_permutation,
|
||||
@ -477,7 +517,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
|
||||
FunctionLibraryDefinition library(OpRegistry::Global(), {});
|
||||
int guaranteed_consts = 0;
|
||||
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
|
||||
"_encapsulate", graph_before,
|
||||
"_encapsulate", "_outside", graph_before,
|
||||
/*rewrite_subgraph_fn=*/
|
||||
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
|
||||
std::vector<int>* input_permutation,
|
||||
@ -502,5 +542,678 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
|
||||
EXPECT_EQ(1, guaranteed_consts);
|
||||
}
|
||||
|
||||
// Test with one function to transform and one outside_compilation cluster.
|
||||
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
*library.add_function() = test::function::XTimesTwo();
|
||||
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
// Give nodes 'c' and 'd' names that collide after lowercasing.
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d = Binary(b, c,
|
||||
b1.opts().WithName("c").WithControlInput(c).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Node* e = Binary(c, d,
|
||||
b1.opts()
|
||||
.WithName("E")
|
||||
.WithControlInputs({b, d})
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f = Binary(c, e,
|
||||
b1.opts().WithName("F").WithControlInput(e).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = test::function::XTimesTwo();
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"C:o:0", "outside_compilation_O1_recv:output:0"},
|
||||
{},
|
||||
{"outside_compilation_O1_recv"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"C:o:0", "c:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
|
||||
{"c"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
NodeBuilder node_builder("F1", "F1", lib_def.get());
|
||||
node_builder.Input(a).Input(b);
|
||||
Node* call = b2.opts().FinalizeBuilder(&node_builder);
|
||||
|
||||
Node* recv =
|
||||
RecvAtHost({DT_FLOAT, DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
|
||||
b2.opts().WithName("E").WithControlInputs({recv, b}));
|
||||
Node* send = SendFromHost({e}, {DT_FLOAT},
|
||||
b2.opts()
|
||||
.WithName("outside_compilation_F1_O1_send")
|
||||
.WithControlInput(e));
|
||||
|
||||
Node* s = NoOp(
|
||||
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
|
||||
|
||||
Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with one function to transform and two outside_compilation clusters.
|
||||
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
|
||||
Node* e = Binary(c, d,
|
||||
b1.opts()
|
||||
.WithName("E")
|
||||
.WithControlInputs({b, d})
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f = Binary(c, e,
|
||||
b1.opts().WithName("F").WithControlInput(e).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Node* g = Binary(e, f,
|
||||
b1.opts()
|
||||
.WithName("G")
|
||||
.WithControlInputs({e, f})
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O2"));
|
||||
Node* h = Binary(d, e,
|
||||
b1.opts()
|
||||
.WithName("H")
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O2"));
|
||||
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
|
||||
Binary(g, i, b1.opts().WithName("J"));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
|
||||
{{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"C:o:0", "outside_compilation_O1_recv:output:0"},
|
||||
{},
|
||||
{"outside_compilation_O1_recv"}},
|
||||
{{"outside_compilation_O2_send"},
|
||||
"_XlaSendToHost",
|
||||
{"D:o:0", "F:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
|
||||
{"F"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"C:o:0", "D:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
|
||||
{"D"}},
|
||||
{{"outside_compilation_O2_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O2_send"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"i_0_retval", "I:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
NodeBuilder node_builder("F1", "F1", lib_def.get());
|
||||
node_builder.Input(a).Input(b);
|
||||
Node* call = b2.opts().FinalizeBuilder(&node_builder);
|
||||
|
||||
Node* recv1 =
|
||||
RecvAtHost({DT_FLOAT, DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
|
||||
b2.opts().WithName("E").WithControlInputs({recv1, b}));
|
||||
Node* send1 = SendFromHost({e}, {DT_FLOAT},
|
||||
b2.opts()
|
||||
.WithName("outside_compilation_F1_O1_send")
|
||||
.WithControlInput(e));
|
||||
|
||||
Node* recv2 =
|
||||
RecvAtHost({DT_FLOAT, DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O2_recv"));
|
||||
Node* g = Binary(e, ops::NodeOut(recv2, 1),
|
||||
b2.opts().WithName("G").WithControlInputs({recv2, e}));
|
||||
Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
|
||||
Node* send2 = SendFromHost(
|
||||
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
|
||||
|
||||
Node* s = NoOp(b2.opts()
|
||||
.WithName("F1_sequencer")
|
||||
.WithControlInputs({recv1, send1, recv2, send2}));
|
||||
|
||||
Binary(g, call, b2.opts().WithName("J").WithControlInput(s));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with two functions to transform, each with one outside_compilation
|
||||
// cluster.
|
||||
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
|
||||
Node* e = Binary(c, d,
|
||||
b1.opts()
|
||||
.WithName("E")
|
||||
.WithControlInputs({b, d})
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f = Binary(c, e,
|
||||
b1.opts().WithName("F").WithControlInput(e).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Node* g = Binary(e, f,
|
||||
b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
|
||||
"_encapsulate", "F2"));
|
||||
Node* h = Binary(d, g,
|
||||
b1.opts()
|
||||
.WithName("H")
|
||||
.WithAttr("_encapsulate", "F2")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* i =
|
||||
Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
|
||||
Binary(g, i, b1.opts().WithName("J"));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"},
|
||||
{"f_0_retval:float", "d_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"C:o:0", "outside_compilation_O1_recv:output:0"},
|
||||
{},
|
||||
{"outside_compilation_O1_recv"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"C:o:0", "D:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
|
||||
{"D"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F2", {"e_0_arg:float", "f_0_arg:float"},
|
||||
{"g_0_retval:float", "i_0_retval:float"}, {},
|
||||
{
|
||||
{{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
|
||||
{{"I"},
|
||||
"BinaryTest",
|
||||
{"f_0_arg", "outside_compilation_O1_recv:output:0"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"G:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 =
|
||||
RecvAtHost({DT_FLOAT, DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
|
||||
b2.opts().WithName("E").WithControlInputs({recv1, b}));
|
||||
Node* send1 = SendFromHost({e}, {DT_FLOAT},
|
||||
b2.opts()
|
||||
.WithName("outside_compilation_F1_O1_send")
|
||||
.WithControlInput(e));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
Node* s1 = NoOp(
|
||||
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
|
||||
|
||||
Node* recv2 = RecvAtHost(
|
||||
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
|
||||
Node* h = Binary(ops::NodeOut(call1, 1), recv2,
|
||||
b2.opts().WithName("H").WithControlInput(s1));
|
||||
Node* send2 = SendFromHost(
|
||||
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
|
||||
|
||||
NodeBuilder node_builder2("F2", "F2", lib_def.get());
|
||||
node_builder2.Input(e).Input(call1);
|
||||
Node* call2 = b2.opts()
|
||||
.WithControlInputs({s1, e, call1})
|
||||
.FinalizeBuilder(&node_builder2);
|
||||
Node* s2 = NoOp(
|
||||
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}));
|
||||
Binary(call2, ops::NodeOut(call2, 1),
|
||||
b2.opts().WithName("J").WithControlInput(s2));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with one outside_compilation cluster that has no inputs from the
|
||||
// compiled subgraph.
|
||||
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
|
||||
Node* e = Unary(a, b1.opts()
|
||||
.WithName("E")
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f =
|
||||
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
|
||||
Unary(f, b1.opts().WithName("G"));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* e = Unary(a, b2.opts().WithName("E"));
|
||||
Node* send1 = SendFromHost(
|
||||
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1));
|
||||
|
||||
Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with one outside_compilation cluster that has no data inputs but has a
|
||||
// control input from the compiled subgraph.
|
||||
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
|
||||
Node* e = Unary(a, b1.opts()
|
||||
.WithName("E")
|
||||
.WithControlInput(d)
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f =
|
||||
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
|
||||
Unary(f, b1.opts().WithName("G"));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({})}},
|
||||
{"D"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 =
|
||||
RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
|
||||
Node* send1 = SendFromHost(
|
||||
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
Node* s1 = NoOp(
|
||||
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
|
||||
|
||||
Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with one outside_compilation cluster that has no outputs from the
|
||||
// compiled subgraph.
|
||||
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
|
||||
Node* e = Unary(d, b1.opts()
|
||||
.WithName("E")
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
|
||||
Binary(e, f, b1.opts().WithName("G"));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"}, "UnaryTest", {"D:o:0"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"D:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 = RecvAtHost(
|
||||
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Unary(recv1, b2.opts().WithName("E"));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1));
|
||||
|
||||
Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with one outside_compilation cluster that has no data outputs but has a
|
||||
// control output to the compiled subgraph.
|
||||
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
|
||||
Node* e = Unary(d, b1.opts()
|
||||
.WithName("E")
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Binary(e, f, b1.opts().WithName("G"));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"D:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 = RecvAtHost(
|
||||
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Unary(recv1, b2.opts().WithName("E"));
|
||||
Node* send1 = SendFromHost({}, {},
|
||||
b2.opts()
|
||||
.WithName("outside_compilation_F1_O1_send")
|
||||
.WithControlInput(e));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
Node* s1 = NoOp(
|
||||
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
|
||||
|
||||
Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test with one outside_compilation cluster that has no outputs from the
|
||||
// compiled subgraph.
|
||||
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
|
||||
Node* e = Unary(a, b1.opts()
|
||||
.WithName("E")
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
|
||||
Binary(e, f, b1.opts().WithName("G"));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"}, "UnaryTest", {"D:o:0"}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* e = Unary(a, b2.opts().WithName("E"));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
|
||||
Binary(e, call1, b2.opts().WithName("G"));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kXlaClusterAttr = "_XlaCluster";
|
||||
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -28,6 +28,10 @@ namespace tensorflow {
|
||||
// encapsulate subgraphs pass.
|
||||
extern const char* const kXlaClusterAttr;
|
||||
|
||||
// The attribute that marks nodes in a cluster to be placed outside the xla
|
||||
// compilation by the encapsulate subgraphs pass.
|
||||
extern const char* const kXlaOutsideCompilationAttr;
|
||||
|
||||
// Pass that marks a subset of operators in the graph with attribute
|
||||
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
|
||||
class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
|
@ -270,8 +270,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/* static */ tensorflow::gtl::ArraySlice<const int64>
|
||||
LayoutUtil::PaddedDimensions(const Shape& shape) {
|
||||
/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
|
||||
const Shape& shape) {
|
||||
CHECK(IsDense(shape));
|
||||
return AsInt64Slice(shape.layout().padded_dimensions());
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ class LayoutUtil {
|
||||
|
||||
// Returns the padded_dimensions array for the given Shape. Requires that the
|
||||
// shape is an array and has a dense layout.
|
||||
static tensorflow::gtl::ArraySlice<const int64> PaddedDimensions(
|
||||
static tensorflow::gtl::ArraySlice<int64> PaddedDimensions(
|
||||
const Shape& shape);
|
||||
|
||||
// Returns the given index of the padded_dimensions array for the given Shape.
|
||||
|
@ -341,7 +341,7 @@ class Literal {
|
||||
|
||||
// Creates a literal of the given shape where each element is `value`.
|
||||
template <typename NativeT>
|
||||
static std::unique_ptr<Literal> CreateFullWithMonotonicDim0MajorLayout(
|
||||
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
|
||||
|
||||
// Creates a new literal from an array. The variants not ending with
|
||||
@ -1233,10 +1233,9 @@ void Literal::PopulateWithValue(NativeT value,
|
||||
}
|
||||
|
||||
template <typename NativeT>
|
||||
/* static */ std::unique_ptr<Literal>
|
||||
Literal::CreateFullWithMonotonicDim0MajorLayout(
|
||||
/* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
|
||||
Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
Shape this_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
|
||||
auto literal = MakeUnique<Literal>();
|
||||
*literal->mutable_shape() = this_shape;
|
||||
|
@ -46,6 +46,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
|
@ -166,6 +166,18 @@ ComputationDataHandle LocalComputationBuilder::Dot(
|
||||
return builder_.Dot(lhs, rhs);
|
||||
}
|
||||
|
||||
ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
|
||||
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
|
||||
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
|
||||
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers) {
|
||||
return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation,
|
||||
dimension_numbers);
|
||||
}
|
||||
|
||||
ComputationDataHandle LocalComputationBuilder::ConvertElementType(
|
||||
const ComputationDataHandle& operand, PrimitiveType new_element_type) {
|
||||
return builder_.ConvertElementType(operand, new_element_type);
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
namespace xla {
|
||||
@ -113,6 +114,14 @@ class LocalComputationBuilder {
|
||||
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
|
||||
const ComputationDataHandle& rhs);
|
||||
|
||||
ComputationDataHandle ConvGeneralDilated(
|
||||
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
|
||||
tensorflow::gtl::ArraySlice<int64> window_strides,
|
||||
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
|
||||
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
|
||||
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers);
|
||||
|
||||
ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
|
||||
PrimitiveType new_element_type);
|
||||
|
||||
|
@ -22,18 +22,19 @@ limitations under the License.
|
||||
//
|
||||
// C++ Python
|
||||
// -------------------------------------+---------------------------------------
|
||||
// ComputationDataHandle <-> long
|
||||
// ArraySlice<int64> <- sequence of long
|
||||
// ArraySlice<ComputationDataHandle> <- sequence of long
|
||||
// ComputationDataHandle <-> int
|
||||
// ArraySlice<int64> <- sequence of int
|
||||
// ArraySlice<ComputationDataHandle> <- sequence of int
|
||||
// Literal <-> (nested tuple of) numpy ndarray
|
||||
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
|
||||
// Shape <-> pair holding (dtype, dimensions)
|
||||
// std::vector<Shape> <- sequence of shape information pairs
|
||||
// PrimitiveType <- int
|
||||
// ArraySlice<pair<int64, in64>> <- sequence of int pairs
|
||||
// ConvolutionDimensionNumbers proto <- corresponding Python proto
|
||||
//
|
||||
// Arrows indicate whether a conversion only ever occurs in one
|
||||
// direction, or whether it is maintained bidirectionally. Also,
|
||||
// "long" and "int" denote the Python types so named, not C.
|
||||
// direction, or whether it is maintained bidirectionally.
|
||||
//
|
||||
// The Python objects corresponding to C++ Literals have the type:
|
||||
//
|
||||
@ -113,6 +114,27 @@ limitations under the License.
|
||||
|
||||
using namespace xla;
|
||||
using namespace xla::swig;
|
||||
|
||||
namespace xla {
|
||||
namespace swig {
|
||||
|
||||
bool GetIntAttr(PyObject* o, const char* field, int64* result) {
|
||||
PyObject* fo = PyObject_GetAttrString(o, field);
|
||||
if (!fo) {
|
||||
return false;
|
||||
}
|
||||
const int64 value = numpy::PyIntOrPyLongToLong(fo);
|
||||
if (value == -1 && PyErr_Occurred()) {
|
||||
Py_DECREF(fo);
|
||||
return false;
|
||||
}
|
||||
Py_DECREF(fo);
|
||||
*result = value;
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
%}
|
||||
|
||||
// Required to use PyArray_* functions.
|
||||
@ -278,6 +300,189 @@ tensorflow::ImportNumpy();
|
||||
$1 = static_cast<PrimitiveType>(value);
|
||||
}
|
||||
|
||||
// ArraySlice<pair<int64, in64>>
|
||||
|
||||
%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
|
||||
(std::vector<std::pair<int64, int64> > temps) {
|
||||
if (!PySequence_Check($input)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
|
||||
return NULL;
|
||||
}
|
||||
const int size = PySequence_Size($input);
|
||||
temps.reserve(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
PyObject* o = PySequence_GetItem($input, i);
|
||||
if (!o) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject* first = PyTuple_GetItem(o, 0);
|
||||
if (!first) {
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
PyObject* first_pyint = numpy::PyNumberToPyInt(first);
|
||||
if (!first_pyint) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"First pair item cannot be converted to int");
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
PyObject* second = PyTuple_GetItem(o, 1);
|
||||
if (!second) {
|
||||
Py_DECREF(o);
|
||||
Py_DECREF(first_pyint);
|
||||
return NULL;
|
||||
}
|
||||
PyObject* second_pyint = numpy::PyNumberToPyInt(second);
|
||||
if (!second_pyint) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
"Second pair item cannot be converted to int");
|
||||
Py_DECREF(o);
|
||||
Py_DECREF(first_pyint);
|
||||
return NULL;
|
||||
}
|
||||
const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint);
|
||||
if (first_value == -1 && PyErr_Occurred()) {
|
||||
Py_DECREF(o);
|
||||
Py_DECREF(first_pyint);
|
||||
Py_DECREF(second_pyint);
|
||||
return NULL;
|
||||
}
|
||||
const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint);
|
||||
if (second_value == -1 && PyErr_Occurred()) {
|
||||
Py_DECREF(o);
|
||||
Py_DECREF(first_pyint);
|
||||
Py_DECREF(second_pyint);
|
||||
return NULL;
|
||||
}
|
||||
temps.push_back(std::make_pair(first_value, second_value));
|
||||
Py_DECREF(o);
|
||||
}
|
||||
$1 = temps;
|
||||
}
|
||||
|
||||
// ConvolutionDimensionNumbers
|
||||
|
||||
%typemap(in) const ConvolutionDimensionNumbers&
|
||||
(ConvolutionDimensionNumbers dimension_numbers) {
|
||||
int64 value;
|
||||
|
||||
if (!GetIntAttr($input, "input_batch_dimension", &value)) {
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.set_input_batch_dimension(value);
|
||||
|
||||
if (!GetIntAttr($input, "input_feature_dimension", &value)) {
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.set_input_feature_dimension(value);
|
||||
|
||||
if (!GetIntAttr($input, "output_batch_dimension", &value)) {
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.set_output_batch_dimension(value);
|
||||
|
||||
if (!GetIntAttr($input, "output_feature_dimension", &value)) {
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.set_output_feature_dimension(value);
|
||||
|
||||
if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) {
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.set_kernel_output_feature_dimension(value);
|
||||
|
||||
if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) {
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.set_kernel_input_feature_dimension(value);
|
||||
|
||||
PyObject* o;
|
||||
int length;
|
||||
|
||||
o = PyObject_GetAttrString($input, "input_spatial_dimensions");
|
||||
if (!o) {
|
||||
return NULL;
|
||||
}
|
||||
length = PySequence_Size(o);
|
||||
if (length == -1) {
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
for (int i = 0; i < length; ++i) {
|
||||
PyObject* item = PySequence_GetItem(o, i);
|
||||
if (!item) {
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
|
||||
if (dimension == -1 && PyErr_Occurred()) {
|
||||
Py_DECREF(item);
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.add_input_spatial_dimensions(dimension);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
Py_DECREF(o);
|
||||
|
||||
o = PyObject_GetAttrString($input, "kernel_spatial_dimensions");
|
||||
if (!o) {
|
||||
return NULL;
|
||||
}
|
||||
length = PySequence_Size(o);
|
||||
if (length == -1) {
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
for (int i = 0; i < length; ++i) {
|
||||
PyObject* item = PySequence_GetItem(o, i);
|
||||
if (!item) {
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
|
||||
if (dimension == -1 && PyErr_Occurred()) {
|
||||
Py_DECREF(item);
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.add_kernel_spatial_dimensions(dimension);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
Py_DECREF(o);
|
||||
|
||||
o = PyObject_GetAttrString($input, "output_spatial_dimensions");
|
||||
if (!o) {
|
||||
return NULL;
|
||||
}
|
||||
length = PySequence_Size(o);
|
||||
if (length == -1) {
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
for (int i = 0; i < length; ++i) {
|
||||
PyObject* item = PySequence_GetItem(o, i);
|
||||
if (!item) {
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
|
||||
if (dimension == -1 && PyErr_Occurred()) {
|
||||
Py_DECREF(item);
|
||||
Py_DECREF(o);
|
||||
return NULL;
|
||||
}
|
||||
dimension_numbers.add_output_spatial_dimensions(dimension);
|
||||
Py_DECREF(item);
|
||||
}
|
||||
Py_DECREF(o);
|
||||
|
||||
$1 = &dimension_numbers;
|
||||
}
|
||||
|
||||
%ignoreall
|
||||
%unignore xla;
|
||||
%unignore xla::swig;
|
||||
@ -314,6 +519,7 @@ tensorflow::ImportNumpy();
|
||||
%unignore xla::swig::LocalComputationBuilder::Lt;
|
||||
%unignore xla::swig::LocalComputationBuilder::Le;
|
||||
%unignore xla::swig::LocalComputationBuilder::Dot;
|
||||
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
|
||||
%unignore xla::swig::LocalComputationBuilder::Add;
|
||||
%unignore xla::swig::LocalComputationBuilder::Sub;
|
||||
%unignore xla::swig::LocalComputationBuilder::Mul;
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
@ -25,6 +26,12 @@ import numpy as np
|
||||
from tensorflow.compiler.xla import xla_data_pb2
|
||||
from tensorflow.compiler.xla.python import pywrap_xla as c_api
|
||||
|
||||
|
||||
class PaddingType(enum.Enum):
|
||||
VALID = 1
|
||||
SAME = 2
|
||||
|
||||
|
||||
_UNARY_OPS = [
|
||||
'Not',
|
||||
'Abs',
|
||||
@ -564,6 +571,79 @@ class ComputationBuilder(object):
|
||||
return _wrap_data_handle(
|
||||
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
|
||||
|
||||
def Conv(self, lhs, rhs, window_strides, padding):
|
||||
"""Enqueues a Conv operation onto the computation.
|
||||
|
||||
Args:
|
||||
lhs: ComputationDataHandle for the rank N+2 array of inputs.
|
||||
rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
|
||||
window_strides: length-N array-like of integer kernel strides.
|
||||
padding: PaddingType representing either 'SAME' or 'VALID' padding.
|
||||
|
||||
Returns: a ComputationDataHandle representing the Conv operation.
|
||||
"""
|
||||
if padding == PaddingType.SAME:
|
||||
lhs_dims = self.GetShape(lhs).dimensions()
|
||||
rhs_dims = self.GetShape(rhs).dimensions()
|
||||
in_shape, filter_shape = lhs_dims[2:], rhs_dims[2:]
|
||||
out_shape = np.ceil(np.true_divide(in_shape, window_strides)).astype(int)
|
||||
pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0)
|
||||
for out_size, stride, filter_size, in_size
|
||||
in zip(out_shape, window_strides, filter_shape, in_shape)]
|
||||
pads = [(pad_size // 2, pad_size - pad_size // 2)
|
||||
for pad_size in pad_sizes]
|
||||
else:
|
||||
pads = [(0, 0)] * len(window_strides)
|
||||
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
|
||||
return _wrap_data_handle(
|
||||
self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
|
||||
_unwrap_data_handle(rhs),
|
||||
window_strides,
|
||||
pads,
|
||||
(),
|
||||
(),
|
||||
dimension_numbers))
|
||||
|
||||
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation):
|
||||
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
|
||||
|
||||
Args:
|
||||
lhs: ComputationDataHandle for the rank N+2 array of inputs.
|
||||
rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
|
||||
window_strides: length-N array-like of kernel strides.
|
||||
padding: length-N array-like of pairs of integers of (low, high) padding.
|
||||
lhs_dilation: length-N array-like of dilation factors.
|
||||
rhs_dilation: length-N array-like of dilation factors.
|
||||
|
||||
Returns:
|
||||
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
|
||||
"""
|
||||
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
|
||||
return _wrap_data_handle(
|
||||
self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
|
||||
_unwrap_data_handle(rhs),
|
||||
window_strides,
|
||||
padding,
|
||||
lhs_dilation,
|
||||
rhs_dilation,
|
||||
dimension_numbers))
|
||||
|
||||
def _GetConvDimensionNumbers(self, num_spatial_dims):
|
||||
"""Create ConvolutionDimensionNumbers proto for convolutions."""
|
||||
nd = num_spatial_dims
|
||||
dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()
|
||||
dimension_numbers.input_batch_dimension = 0
|
||||
dimension_numbers.input_feature_dimension = 1
|
||||
dimension_numbers.output_batch_dimension = 0
|
||||
dimension_numbers.output_feature_dimension = 1
|
||||
dimension_numbers.kernel_output_feature_dimension = 0
|
||||
dimension_numbers.kernel_input_feature_dimension = 1
|
||||
dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
|
||||
dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
|
||||
dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
|
||||
return dimension_numbers
|
||||
|
||||
|
||||
def _forward_methods_to_local_builder():
|
||||
"""Forward remaining ComputationBuilder methods to the C API.
|
||||
|
@ -386,6 +386,46 @@ class SingleOpTest(LocalComputationTest):
|
||||
c.Dot(c.Constant(lhs), c.Constant(rhs))
|
||||
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
|
||||
|
||||
def testConvF32Same(self):
|
||||
c = self._NewComputation()
|
||||
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
|
||||
lhs = a(1, 2, 3, 4)
|
||||
rhs = a(1, 2, 1, 2) * 10
|
||||
c.Conv(c.Constant(lhs), c.Constant(rhs),
|
||||
[1, 1], xla_client.PaddingType.SAME)
|
||||
result = np.array([[[[640., 700., 760., 300.],
|
||||
[880., 940., 1000., 380.],
|
||||
[1120., 1180., 1240., 460.]]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testConvF32Valid(self):
|
||||
c = self._NewComputation()
|
||||
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
|
||||
lhs = a(1, 2, 3, 4)
|
||||
rhs = a(1, 2, 1, 2) * 10
|
||||
c.Conv(c.Constant(lhs), c.Constant(rhs),
|
||||
[2, 1], xla_client.PaddingType.VALID)
|
||||
result = np.array([[[[640., 700., 760.],
|
||||
[1120., 1180., 1240.]]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testConvWithGeneralPaddingF32(self):
|
||||
c = self._NewComputation()
|
||||
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
|
||||
lhs = a(1, 1, 2, 3)
|
||||
rhs = a(1, 1, 1, 2) * 10
|
||||
strides = [1, 1]
|
||||
pads = [(1, 0), (0, 1)]
|
||||
lhs_dilation = (2, 1)
|
||||
rhs_dilation = (1, 1)
|
||||
c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs),
|
||||
strides, pads, lhs_dilation, rhs_dilation)
|
||||
result = np.array([[[[0., 0., 0.],
|
||||
[10., 20., 0.],
|
||||
[0., 0., 0.],
|
||||
[40., 50., 0.]]]])
|
||||
self._ExecuteAndCompareClose(c, expected=result)
|
||||
|
||||
def testBooleanNot(self):
|
||||
c = self._NewComputation()
|
||||
arr = NumpyArrayBool([True, False, True])
|
||||
|
@ -1182,6 +1182,11 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
|
||||
if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) {
|
||||
return ReplaceWithNewInstruction(
|
||||
pad, HloInstruction::CreateBroadcast(pad->shape(),
|
||||
pad->mutable_operand(1), {}));
|
||||
}
|
||||
// Eliminate nop pads (padding all zero), and replace a pad with negative
|
||||
// padding with a pad with non-negative padding followed by a slice.
|
||||
bool all_zero = true;
|
||||
@ -1624,6 +1629,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleReduceWindow(
|
||||
HloInstruction* reduce_window) {
|
||||
if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) {
|
||||
return ReplaceWithNewInstruction(
|
||||
reduce_window,
|
||||
HloInstruction::CreateBroadcast(reduce_window->shape(),
|
||||
reduce_window->mutable_operand(1), {}));
|
||||
}
|
||||
auto operand = reduce_window->mutable_operand(0);
|
||||
const Window& window = reduce_window->window();
|
||||
auto function = reduce_window->to_apply();
|
||||
@ -1694,7 +1705,6 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
|
||||
auto operand = transpose->mutable_operand(0);
|
||||
|
||||
if (std::is_sorted(transpose->dimensions().begin(),
|
||||
transpose->dimensions().end())) {
|
||||
VLOG(10) << "deleting no-op transpose";
|
||||
@ -1721,6 +1731,18 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
|
||||
HloInstruction* convolution) {
|
||||
auto lhs = convolution->mutable_operand(0);
|
||||
auto rhs = convolution->mutable_operand(1);
|
||||
if (ShapeUtil::HasZeroElements(lhs->shape()) ||
|
||||
ShapeUtil::HasZeroElements(rhs->shape())) {
|
||||
return ReplaceWithNewInstruction(
|
||||
convolution,
|
||||
HloInstruction::CreateBroadcast(
|
||||
convolution->shape(),
|
||||
computation_->AddInstruction(HloInstruction::CreateConvert(
|
||||
ShapeUtil::MakeShape(convolution->shape().element_type(), {}),
|
||||
computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))),
|
||||
{}));
|
||||
}
|
||||
const auto& window = convolution->window();
|
||||
if (!enable_conv_simplification_) {
|
||||
return Status::OK();
|
||||
@ -1813,18 +1835,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
|
||||
|
||||
// We already checked feature_dimension is most minor, so data in input_shape
|
||||
// and row-major {conv_width,input_channels} are bitwise identical.
|
||||
const Shape new_input_shape =
|
||||
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
input_shape.element_type(), {conv_width, input_channels});
|
||||
const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
input_shape.element_type(), {conv_width, input_channels});
|
||||
// We already checked input_feature_dimension is more major than
|
||||
// output_feature_dimension, so data in filter_shape and row-major
|
||||
// {input_channels,output_channels} are bitwise identical.
|
||||
const Shape new_filter_shape =
|
||||
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
filter_shape.element_type(), {input_channels, output_channels});
|
||||
const Shape dot_output_shape =
|
||||
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
convolution_shape.element_type(), {conv_width, output_channels});
|
||||
const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
filter_shape.element_type(), {input_channels, output_channels});
|
||||
const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
convolution_shape.element_type(), {conv_width, output_channels});
|
||||
|
||||
// We cannot insert bitcasts if the layouts will not be compatible.
|
||||
// TODO(b/33178038): Consider inserting a transpose if a bitcast would be
|
||||
|
@ -816,6 +816,120 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
||||
1);
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));
|
||||
|
||||
HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));
|
||||
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.add_input_spatial_dimensions(1);
|
||||
dnums.set_input_feature_dimension(2);
|
||||
|
||||
dnums.set_output_batch_dimension(0);
|
||||
dnums.add_output_spatial_dimensions(1);
|
||||
dnums.set_output_feature_dimension(2);
|
||||
|
||||
dnums.add_kernel_spatial_dimensions(0);
|
||||
dnums.set_kernel_input_feature_dimension(1);
|
||||
dnums.set_kernel_output_feature_dimension(2);
|
||||
Window window;
|
||||
WindowDimension* dim = window.add_dimensions();
|
||||
dim->set_size(3);
|
||||
dim->set_padding_low(0);
|
||||
dim->set_padding_high(0);
|
||||
dim->set_stride(1);
|
||||
dim->set_window_dilation(1);
|
||||
dim->set_base_dilation(1);
|
||||
dim->set_window_reversal(false);
|
||||
// Create add computation.
|
||||
std::unique_ptr<HloModule> module = CreateNewModule();
|
||||
builder.AddInstruction(HloInstruction::CreateConvolve(
|
||||
ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
|
||||
non_bitcasting_callback());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Convolution(lhs, rhs));
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Broadcast(op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
|
||||
Window window;
|
||||
for (int64 i = 0; i < 2; ++i) {
|
||||
WindowDimension* dim = window.add_dimensions();
|
||||
dim->set_size(1);
|
||||
dim->set_padding_low(1);
|
||||
dim->set_padding_high(1);
|
||||
dim->set_window_dilation(1);
|
||||
dim->set_base_dilation(1);
|
||||
}
|
||||
// Create add computation.
|
||||
std::unique_ptr<HloModule> module = CreateNewModule();
|
||||
HloComputation* add_computation = nullptr;
|
||||
{
|
||||
HloComputation::Builder builder(TestName() + ".add");
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||
HloInstruction* p0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, scalar_shape, "p0"));
|
||||
HloInstruction* p1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
|
||||
add_computation = module->AddEmbeddedComputation(builder.Build());
|
||||
}
|
||||
builder.AddInstruction(HloInstruction::CreateReduceWindow(
|
||||
ShapeUtil::MakeShape(F32, {5, 2}), param,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
|
||||
window, add_computation));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
|
||||
non_bitcasting_callback());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::ReduceWindow(param, op::Constant()));
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Broadcast(op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
HloInstruction* param =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
|
||||
PaddingConfig padding;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions();
|
||||
dimension->set_edge_padding_low(1);
|
||||
dimension->set_edge_padding_high(1);
|
||||
dimension->set_interior_padding(0);
|
||||
}
|
||||
builder.AddInstruction(HloInstruction::CreatePad(
|
||||
ShapeUtil::MakeShape(F32, {5, 2}), param,
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
|
||||
padding));
|
||||
std::unique_ptr<HloModule> module = CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Pad(param, op::Constant()));
|
||||
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
|
||||
non_bitcasting_callback());
|
||||
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Broadcast(op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
|
||||
@ -1309,7 +1423,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param0 =
|
||||
builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(F32, {2, 2, 2}),
|
||||
0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}),
|
||||
"param0"));
|
||||
|
||||
HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
|
@ -28,8 +28,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace se = ::perftools::gputools;
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<GlobalDataHandle> AllocationTracker::Register(
|
||||
|
@ -96,6 +96,26 @@ class ConvolutionThunk : public Thunk {
|
||||
return !best_algorithm_.has_value();
|
||||
}
|
||||
|
||||
// Return true if scratch memory is needed to execute the thunk, that is
|
||||
// either the best algorithm hasn't been chosen or the best algorithm is not
|
||||
// the same as the no-scratch algorithm. This is because that the execution
|
||||
// of the thunk is asynchronous, and the scratch allocator goes out of
|
||||
// scope before the thunk finishes execution. Returning true tells the stream
|
||||
// executor to make future thunks wait for this thunk to avoid reusing the
|
||||
// deallocated scratch memory until this thunk is done with it.
|
||||
bool ShouldBlockFutureThunks() {
|
||||
if (!best_algorithm_.has_value()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const perftools::gputools::dnn::AlgorithmDesc& best_alg =
|
||||
best_algorithm_->algorithm();
|
||||
const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg =
|
||||
best_algorithm_->algorithm_no_scratch();
|
||||
return (!best_alg.is_default() || !no_scratch_best_alg.is_default() ||
|
||||
!(best_alg == no_scratch_best_alg));
|
||||
}
|
||||
|
||||
private:
|
||||
tensorflow::Status ConvolveWithTune(
|
||||
const perftools::gputools::dnn::BatchDescriptor& input_descriptor,
|
||||
|
@ -155,6 +155,7 @@ Status GpuExecutable::ExecuteThunks(
|
||||
run_options->BorrowStream(main_stream->parent()->device_ordinal()));
|
||||
}
|
||||
|
||||
std::map<int32, const Thunk*> last_blocking_thunk_for_stream;
|
||||
std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
|
||||
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
|
||||
TF_RETURN_IF_ERROR(thunk->Initialize(*this));
|
||||
@ -167,10 +168,17 @@ Status GpuExecutable::ExecuteThunks(
|
||||
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
|
||||
}
|
||||
|
||||
if (last_blocking_thunk_for_stream.count(stream_no)) {
|
||||
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event,
|
||||
last_blocking_thunk_for_stream[stream_no])
|
||||
.get());
|
||||
}
|
||||
|
||||
// If this thunk requests it, wait for all currently-executing thunks to
|
||||
// finish. This is useful e.g. if the thunk is about to perform autotuning.
|
||||
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
|
||||
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
|
||||
last_blocking_thunk_for_stream.clear();
|
||||
}
|
||||
|
||||
profiler.StartOperation();
|
||||
@ -178,11 +186,14 @@ Status GpuExecutable::ExecuteThunks(
|
||||
<< thunk->hlo_instruction()->ToString() << " on stream "
|
||||
<< stream_no;
|
||||
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
|
||||
if (thunk_schedule_->Depended(thunk)) {
|
||||
if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) {
|
||||
auto finish_event = MakeUnique<se::Event>(main_stream->parent());
|
||||
finish_event->Init();
|
||||
stream->ThenRecordEvent(finish_event.get());
|
||||
thunk_to_finish_event[thunk] = std::move(finish_event);
|
||||
if (thunk->ShouldBlockFutureThunks()) {
|
||||
last_blocking_thunk_for_stream[stream_no] = thunk;
|
||||
}
|
||||
}
|
||||
profiler.FinishOperation(thunk->hlo_instruction());
|
||||
}
|
||||
|
@ -419,8 +419,8 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
|
||||
(segs.size() == i ? shape.dimensions().size() : segs[i]),
|
||||
1, std::multiplies<int64>()));
|
||||
}
|
||||
return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(),
|
||||
dimensions);
|
||||
return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
|
||||
dimensions);
|
||||
}
|
||||
|
||||
// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
|
||||
@ -442,11 +442,13 @@ std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
|
||||
}
|
||||
}
|
||||
auto segs = ConsecutiveSegments(perm);
|
||||
Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a);
|
||||
Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b);
|
||||
Shape norm_a =
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
|
||||
Shape norm_b =
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
|
||||
if (3 == segs.size() && 0 == perm[0]) {
|
||||
Shape reduced_a = MergeDimensions(segs, norm_a);
|
||||
Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
b.element_type(),
|
||||
Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
|
||||
return std::make_tuple(true, reduced_a, reduced_b);
|
||||
@ -460,10 +462,11 @@ std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
|
||||
bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
|
||||
return 3 == b.dimensions().size() &&
|
||||
ShapeUtil::Compatible(
|
||||
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a),
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
|
||||
ShapeUtil::PermuteDimensions(
|
||||
{0, 2, 1},
|
||||
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b)));
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
b)));
|
||||
}
|
||||
|
||||
// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from
|
||||
@ -495,9 +498,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
|
||||
CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
|
||||
|
||||
Shape input_shape =
|
||||
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape());
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
input.GetShape());
|
||||
Shape output_shape =
|
||||
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape());
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
output.GetShape());
|
||||
input = input.CastToShape(input_shape, builder);
|
||||
output = output.CastToShape(output_shape, builder);
|
||||
|
||||
@ -615,7 +620,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
|
||||
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
|
||||
builder))),
|
||||
builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
|
||||
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
|
||||
builder);
|
||||
const llvm_ir::IrArray::Index input_tile_origin = ({
|
||||
@ -811,14 +816,15 @@ Status IrEmitterUnnested::EmitColumnReduction(
|
||||
// input_shape to normalized_input_shape and a reshape from
|
||||
// normalized_input_shape to input_matrix_shape.
|
||||
const Shape normalized_input_shape =
|
||||
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
input_shape);
|
||||
auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
|
||||
const std::vector<int64> transpose_dimension_mapping(
|
||||
input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
|
||||
|
||||
const Shape input_matrix_shape =
|
||||
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
input_shape.element_type(), {height, width});
|
||||
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
|
||||
{height, width});
|
||||
const llvm_ir::IrArray::Index input_matrix_index(
|
||||
{y, x}, input_matrix_shape, &ir_builder_);
|
||||
const llvm_ir::IrArray::Index input_index =
|
||||
@ -1054,13 +1060,14 @@ Status IrEmitterUnnested::EmitRowReduction(
|
||||
// from input_shape to normalized_input_shape and a reshape from
|
||||
// normalized_input_shape to input_3d_tensor_shape.
|
||||
const Shape normalized_input_shape =
|
||||
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
input_shape);
|
||||
auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
|
||||
const std::vector<int64> transpose_dimension_mapping(
|
||||
input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
|
||||
const Shape input_3d_tensor_shape =
|
||||
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
input_shape.element_type(), {depth, height, width});
|
||||
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
|
||||
{depth, height, width});
|
||||
const llvm_ir::IrArray::Index input_3d_tensor_index(
|
||||
{z, y, x}, input_3d_tensor_shape, &ir_builder_);
|
||||
const llvm_ir::IrArray::Index input_index =
|
||||
|
@ -83,6 +83,16 @@ class Thunk {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Indicates whether thunks scheduled after this one should wait for this one
|
||||
// to complete before running. For example, a convolution thunk creates a
|
||||
// scratch allocator, then kicks off a convolution in cudnn via the stream
|
||||
// executor. When the stream executor call returns, the scratch allocator goes
|
||||
// out of scope, and the scratch memory is deallocated. In this case, the
|
||||
// convolution thunk needs to return true so that future thunks wait for the
|
||||
// convolution thunk to avoid reusing the deallocated memory until the
|
||||
// convolution thunk is done with it.
|
||||
virtual bool ShouldBlockFutureThunks() { return false; }
|
||||
|
||||
// Execute the kernel for the thunk on the given stream. This method must be
|
||||
// called after Initialize and can be called multiple times over Thunk's
|
||||
// lifetime. Stream argument must be non-null.
|
||||
|
@ -1361,7 +1361,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
|
||||
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
|
||||
std::vector<int64> input_dims(6, 4);
|
||||
std::unique_ptr<Literal> arg_literal =
|
||||
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
|
||||
Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
|
||||
|
||||
HloInstruction* arg_instruction =
|
||||
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
|
||||
@ -1414,7 +1414,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
|
||||
|
||||
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
|
||||
std::unique_ptr<Literal> result_literal =
|
||||
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 8.0f);
|
||||
Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
|
||||
LiteralTestUtil::ExpectEqual(*result_literal, *result);
|
||||
}
|
||||
|
||||
|
@ -478,7 +478,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
|
||||
// Add constraints for kCustomCall instruction operands and instructions.
|
||||
// For now we only support major-first layouts for all inputs and outputs.
|
||||
Shape result_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
instruction->shape().element_type(),
|
||||
AsInt64Slice(instruction->shape().dimensions()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -491,7 +491,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
}
|
||||
|
||||
Shape row_major_operand_shape =
|
||||
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
operand_shape.element_type(),
|
||||
AsInt64Slice(operand_shape.dimensions()));
|
||||
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
|
||||
|
@ -189,20 +189,21 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
.ValueOrDie();
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
|
||||
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
|
||||
std::vector<int64> layout(dimensions.size());
|
||||
std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
|
||||
return MakeShapeWithLayout(element_type, dimensions, layout);
|
||||
}
|
||||
|
||||
/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(
|
||||
/* static */ Shape
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
const Shape& shape) {
|
||||
std::vector<int64> dims(shape.dimensions_size());
|
||||
for (int i = 0; i < shape.dimensions_size(); ++i) {
|
||||
dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
|
||||
}
|
||||
return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims);
|
||||
return MakeShapeWithDescendingLayout(shape.element_type(), dims);
|
||||
}
|
||||
|
||||
/* static */ void ShapeUtil::PopulateShape(
|
||||
@ -1148,9 +1149,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
// as input_shape/output_shape and the dimension-0-major layout. These two
|
||||
// shapes are used for conversion between logical linear indices and
|
||||
// multi-dimensional indices.
|
||||
Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout(
|
||||
Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
|
||||
input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
|
||||
Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout(
|
||||
Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
|
||||
output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));
|
||||
|
||||
for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) {
|
||||
|
@ -268,14 +268,18 @@ class ShapeUtil {
|
||||
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
|
||||
tensorflow::gtl::ArraySlice<int64> minor_to_major);
|
||||
|
||||
// Constructs a new shape with major-first layout.
|
||||
static Shape MakeShapeWithMonotonicDim0MajorLayout(
|
||||
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
|
||||
static Shape MakeShapeWithDescendingLayout(
|
||||
PrimitiveType element_type,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions);
|
||||
|
||||
// Returns a new shape with major-first layout that has the same layout of
|
||||
// elements with a different shape.
|
||||
static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape);
|
||||
// Returns a new Shape based on the given Shape with low-dimension-major
|
||||
// layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
|
||||
// rearranged so that it has the same in-memory layout as the given shape.
|
||||
//
|
||||
// For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
|
||||
static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
const Shape& shape);
|
||||
|
||||
// As MakeShape, but the object to write to is passed in.
|
||||
static void PopulateShape(PrimitiveType element_type,
|
||||
|
@ -393,7 +393,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
|
||||
auto shape = ShapeUtil::MakeShape(F32, input_dims);
|
||||
|
||||
std::unique_ptr<Literal> arg_literal =
|
||||
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
|
||||
Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
|
||||
|
||||
const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
|
||||
|
||||
@ -402,7 +402,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
|
||||
|
||||
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
|
||||
std::unique_ptr<Literal> expected =
|
||||
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 9.0f);
|
||||
Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
|
||||
|
||||
ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
|
||||
}
|
||||
|
@ -397,6 +397,7 @@ py_test(
|
||||
srcs = ["scan_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":dataset_serialization_test",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
@ -491,6 +492,25 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "unique_dataset_op_test",
|
||||
size = "small",
|
||||
srcs = ["unique_dataset_op_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":dataset_serialization_test",
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
"//tensorflow/contrib/stateless",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "zip_dataset_op_test",
|
||||
size = "small",
|
||||
|
@ -735,6 +735,20 @@ class BatchDatasetSerializationTest(
|
||||
lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
|
||||
num_outputs)
|
||||
|
||||
def _build_dataset_dense_to_sparse(self, components):
|
||||
return dataset_ops.Dataset.from_tensor_slices(components).map(
|
||||
lambda x: array_ops.fill([x], x)).apply(
|
||||
batching.dense_to_sparse_batch(4, [12]))
|
||||
|
||||
def testDenseToSparseBatchDatasetCore(self):
|
||||
components = np.random.randint(5, size=(40,)).astype(np.int32)
|
||||
diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
|
||||
|
||||
num_outputs = len(components) // 4
|
||||
self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
|
||||
lambda: self._build_dataset_dense_to_sparse(diff_comp),
|
||||
num_outputs)
|
||||
|
||||
|
||||
class PaddedBatchDatasetSerializationTest(
|
||||
dataset_serialization_test_base.DatasetSerializationTestBase):
|
||||
|
@ -21,6 +21,7 @@ import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
|
||||
from tensorflow.contrib.data.python.ops import scan_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -124,5 +125,18 @@ class ScanDatasetTest(test.TestCase):
|
||||
scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
|
||||
|
||||
|
||||
class ScanDatasetSerialzationTest(
|
||||
dataset_serialization_test_base.DatasetSerializationTestBase):
|
||||
|
||||
def _build_dataset(self, num_elements):
|
||||
return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
|
||||
scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
|
||||
|
||||
def testScanCore(self):
|
||||
num_output = 5
|
||||
self.run_core_tests(lambda: self._build_dataset(num_output),
|
||||
lambda: self._build_dataset(2), num_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -0,0 +1,96 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the experimental input pipeline ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
|
||||
from tensorflow.contrib.data.python.ops import dataset_ops
|
||||
from tensorflow.contrib.data.python.ops import unique
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
class UniqueDatasetTest(test.TestCase):
|
||||
|
||||
def _testSimpleHelper(self, dtype, test_cases):
|
||||
"""Test the `unique()` transformation on a list of test cases.
|
||||
|
||||
Args:
|
||||
dtype: The `dtype` of the elements in each test case.
|
||||
test_cases: A list of pairs of lists. The first component is the test
|
||||
input that will be passed to the transformation; the second component
|
||||
is the expected sequence of outputs from the transformation.
|
||||
"""
|
||||
|
||||
# The `current_test_case` will be updated when we loop over `test_cases`
|
||||
# below; declare it here so that the generator can capture it once.
|
||||
current_test_case = []
|
||||
dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
|
||||
dtype).apply(unique.unique())
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
for test_case, expected in test_cases:
|
||||
current_test_case = test_case
|
||||
sess.run(iterator.initializer)
|
||||
for element in expected:
|
||||
if dtype == dtypes.string:
|
||||
element = compat.as_bytes(element)
|
||||
self.assertAllEqual(element, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testSimpleInt(self):
|
||||
for dtype in [dtypes.int32, dtypes.int64]:
|
||||
self._testSimpleHelper(dtype, [
|
||||
([], []),
|
||||
([1], [1]),
|
||||
([1, 1, 1, 1, 1, 1, 1], [1]),
|
||||
([1, 2, 3, 4], [1, 2, 3, 4]),
|
||||
([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
|
||||
([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
|
||||
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
|
||||
])
|
||||
|
||||
def testSimpleString(self):
|
||||
self._testSimpleHelper(dtypes.string, [
|
||||
([], []),
|
||||
(["hello"], ["hello"]),
|
||||
(["hello", "hello", "hello"], ["hello"]),
|
||||
(["hello", "world"], ["hello", "world"]),
|
||||
(["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
|
||||
])
|
||||
|
||||
|
||||
class UniqueSerializationTest(
|
||||
dataset_serialization_test_base.DatasetSerializationTestBase):
|
||||
|
||||
def testUnique(self):
|
||||
|
||||
def build_dataset(num_elements, unique_elem_range):
|
||||
return dataset_ops.Dataset.range(num_elements).map(
|
||||
lambda x: x % unique_elem_range).apply(unique.unique())
|
||||
|
||||
self.run_core_tests(lambda: build_dataset(200, 100),
|
||||
lambda: build_dataset(40, 100), 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -105,6 +105,7 @@ py_library(
|
||||
"resampling.py",
|
||||
"scan_ops.py",
|
||||
"stats_ops.py",
|
||||
"unique.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
|
82
tensorflow/contrib/data/python/ops/unique.py
Normal file
82
tensorflow/contrib/data/python/ops/unique.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Unique element dataset transformations."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import sparse
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
|
||||
|
||||
def unique():
|
||||
"""Creates a `Dataset` from another `Dataset`, discarding duplicates.
|
||||
|
||||
Use this transformation to produce a dataset that contains one instance of
|
||||
each unique element in the input. For example:
|
||||
|
||||
```python
|
||||
dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
|
||||
|
||||
# Using `unique()` will drop the duplicate elements.
|
||||
dataset = dataset.apply(tf.contrib.data.unique()) # ==> { 1, 37, 2 }
|
||||
```
|
||||
|
||||
Returns:
|
||||
A `Dataset` transformation function, which can be passed to
|
||||
@{tf.data.Dataset.apply}.
|
||||
"""
|
||||
|
||||
def _apply_fn(dataset):
|
||||
return UniqueDataset(dataset)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
|
||||
class UniqueDataset(dataset_ops.Dataset):
|
||||
"""A `Dataset` contains the unique elements from its input."""
|
||||
|
||||
def __init__(self, input_dataset):
|
||||
"""See `unique()` for details."""
|
||||
super(UniqueDataset, self).__init__()
|
||||
self._input_dataset = input_dataset
|
||||
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
|
||||
dtypes.string):
|
||||
raise TypeError(
|
||||
"`tf.contrib.data.unique()` only supports inputs with a single "
|
||||
"`tf.int32`, `tf.int64`, or `tf.string` component.")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
return gen_dataset_ops.unique_dataset(
|
||||
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
|
||||
output_shapes=nest.flatten(
|
||||
sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
|
||||
output_types=nest.flatten(
|
||||
sparse.as_dense_types(self.output_types, self.output_classes)))
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._input_dataset.output_classes
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._input_dataset.output_shapes
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._input_dataset.output_types
|
@ -110,6 +110,7 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/kfac/python/ops:utils",
|
||||
"//tensorflow/contrib/tpu",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
|
@ -22,11 +22,14 @@ import numpy as np
|
||||
import numpy.random as npr
|
||||
|
||||
from tensorflow.contrib.kfac.python.ops import utils
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -95,6 +98,18 @@ class SubGraphTest(test.TestCase):
|
||||
filtered_list = sub_graph.filter_list(input_list)
|
||||
self.assertEqual(filtered_list, [b])
|
||||
|
||||
def testVariableUses(self):
|
||||
with ops.Graph().as_default():
|
||||
var = variable_scope.get_variable('var', shape=[10, 10])
|
||||
resource_var = variable_scope.get_variable(
|
||||
'resource_var', shape=[10, 10], use_resource=True)
|
||||
x = array_ops.zeros([3, 10])
|
||||
z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
|
||||
z1 = math_ops.matmul(x, resource_var)
|
||||
sub_graph = utils.SubGraph((z0, z1))
|
||||
self.assertEqual(2, sub_graph.variable_uses(var))
|
||||
self.assertEqual(1, sub_graph.variable_uses(resource_var))
|
||||
|
||||
|
||||
class UtilsTest(test.TestCase):
|
||||
|
||||
@ -253,6 +268,25 @@ class UtilsTest(test.TestCase):
|
||||
np_inv = np.linalg.inv(x + damp * np.eye(size))
|
||||
self.assertAllClose(sess.run(tf_inv), np_inv)
|
||||
|
||||
def testCrossReplicaMean(self):
|
||||
"""Ensures that cross_replica_mean() executes only when num_shards > 1."""
|
||||
with ops.Graph().as_default():
|
||||
with tpu_function.tpu_shard_context(4):
|
||||
tensor = array_ops.zeros([], dtype=dtypes.float32)
|
||||
mean = utils.cross_replica_mean(tensor)
|
||||
self.assertNotEqual(mean, tensor)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with tpu_function.tpu_shard_context(1):
|
||||
tensor = array_ops.zeros([], dtype=dtypes.float32)
|
||||
mean = utils.cross_replica_mean(tensor)
|
||||
self.assertEqual(mean, tensor)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with self.assertRaises(ValueError): # Outside of TPU context.
|
||||
tensor = array_ops.zeros([], dtype=dtypes.float32)
|
||||
mean = utils.cross_replica_mean(tensor)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -196,6 +196,7 @@ py_library(
|
||||
srcs = ["utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/tpu",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
|
@ -267,6 +267,10 @@ class FisherFactor(object):
|
||||
new_cov = math_ops.add_n(
|
||||
tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
|
||||
|
||||
# Synchronize value across all TPU cores.
|
||||
if utils.on_tpu():
|
||||
new_cov = utils.cross_replica_mean(new_cov)
|
||||
|
||||
return moving_averages.assign_moving_average(
|
||||
self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
|
||||
|
||||
|
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.tpu.python.ops import tpu_ops
|
||||
from tensorflow.contrib.tpu.python.tpu import tpu_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -27,6 +29,8 @@ from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
# Method used for inverting matrices.
|
||||
POSDEF_INV_METHOD = "cholesky"
|
||||
@ -226,11 +230,13 @@ class SubGraph(object):
|
||||
"""
|
||||
|
||||
def __init__(self, outputs):
|
||||
# Set of all ancestor Tensors, Ops to 'outputs'.
|
||||
self._members = set()
|
||||
|
||||
self._recurse_add(outputs)
|
||||
|
||||
def _recurse_add(self, nodes):
|
||||
"""Recursively adds all of nodes' ancestors."""
|
||||
for node in nodes:
|
||||
if node in self._members:
|
||||
continue
|
||||
@ -246,8 +252,25 @@ class SubGraph(object):
|
||||
return node in self._members
|
||||
|
||||
def variable_uses(self, var):
|
||||
"""Computes number of times a variable is used."""
|
||||
return len(self._members.intersection(set(var.value().consumers())))
|
||||
"""Computes number of times a variable is used.
|
||||
|
||||
Args:
|
||||
var: Variable or ResourceVariable instance.
|
||||
|
||||
Returns:
|
||||
Number of times a variable is used within this subgraph.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'var' is not a variable type.
|
||||
"""
|
||||
if isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
var = var.handle
|
||||
elif isinstance(var, variables.Variable):
|
||||
var = var.value()
|
||||
else:
|
||||
raise ValueError("%s does not appear to be a variable." % str(var))
|
||||
|
||||
return len(self._members.intersection(set(var.consumers())))
|
||||
|
||||
def filter_list(self, node_list):
|
||||
"""Filters 'node_list' to nodes in this subgraph."""
|
||||
@ -292,5 +315,34 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
|
||||
|
||||
return dysdx
|
||||
|
||||
|
||||
def on_tpu():
|
||||
"""Returns True when building a TPU computation."""
|
||||
return tpu_function.get_tpu_context().number_of_shards is not None
|
||||
|
||||
|
||||
def cross_replica_mean(tensor, name=None):
|
||||
"""Takes mean value of a Tensor across all TPU cores.
|
||||
|
||||
Args:
|
||||
tensor: Tensor to be synchronized.
|
||||
name: None or string. Name of Op.
|
||||
|
||||
Returns:
|
||||
Average of Tensor across all TPU cores.
|
||||
|
||||
Raises:
|
||||
ValueError: If called outside of TPU context.
|
||||
"""
|
||||
with ops.name_scope(name, "cross_replica_mean", [tensor]):
|
||||
num_shards = tpu_function.get_tpu_context().number_of_shards
|
||||
if num_shards is None:
|
||||
raise ValueError(
|
||||
"Cannot take cross_replica_mean() outside of TPU Context.")
|
||||
if num_shards == 1:
|
||||
return tensor
|
||||
return tpu_ops.cross_replica_sum(tensor / num_shards)
|
||||
|
||||
|
||||
# TODO(b/69623235): Add a function for finding tensors that share gradients
|
||||
# to eliminate redundant fisher factor computations.
|
||||
|
@ -179,6 +179,10 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
|
||||
ConcatenateTensorBuffers<ArrayDataType::kInt64>(
|
||||
input_arrays, concatenation_axis, &concatenated_array);
|
||||
break;
|
||||
case ArrayDataType::kString:
|
||||
ConcatenateTensorBuffers<ArrayDataType::kString>(
|
||||
input_arrays, concatenation_axis, &concatenated_array);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "ArrayDataType not supported";
|
||||
}
|
||||
|
@ -52,6 +52,7 @@ using tensorflow::DT_BOOL;
|
||||
using tensorflow::DT_FLOAT;
|
||||
using tensorflow::DT_INT32;
|
||||
using tensorflow::DT_INT64;
|
||||
using tensorflow::DT_STRING;
|
||||
using tensorflow::DT_UINT8;
|
||||
using tensorflow::GraphDef;
|
||||
using tensorflow::NodeDef;
|
||||
@ -135,6 +136,8 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
|
||||
return ArrayDataType::kInt32;
|
||||
else if (dtype == DT_INT64)
|
||||
return ArrayDataType::kInt64;
|
||||
else if (dtype == DT_STRING)
|
||||
return ArrayDataType::kString;
|
||||
else
|
||||
LOG(INFO) << "Unsupported data type in placehoder op: " << dtype;
|
||||
return ArrayDataType::kNone;
|
||||
@ -236,6 +239,27 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
|
||||
}
|
||||
}
|
||||
|
||||
void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
|
||||
CHECK_EQ(input_tensor.dtype(), DT_STRING);
|
||||
const auto& input_shape = input_tensor.tensor_shape();
|
||||
CHECK_LE(input_shape.dim_size(), 4);
|
||||
ImportShape(input_shape.dim(), output_array->mutable_shape());
|
||||
int input_flat_size = 1;
|
||||
for (int k = 0; k < input_shape.dim_size(); k++) {
|
||||
input_flat_size *= input_shape.dim(k).size();
|
||||
}
|
||||
auto& output_string_data =
|
||||
output_array->GetMutableBuffer<ArrayDataType::kString>().data;
|
||||
output_string_data.resize(input_flat_size);
|
||||
if (input_flat_size != input_tensor.string_val_size()) {
|
||||
LOG(FATAL) << "Input_content string_val doesn't have the right "
|
||||
"dimensions for this string tensor.";
|
||||
}
|
||||
for (int i = 0; i < input_flat_size; ++i) {
|
||||
output_string_data[i] = input_tensor.string_val(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Count the number of inputs of a given node. If
|
||||
// `tf_import_flags.drop_control_dependency` is true, count the number of
|
||||
// non-control-dependency inputs.
|
||||
@ -261,23 +285,30 @@ void ConvertConstOperator(const NodeDef& node,
|
||||
const auto dtype = GetDataTypeAttr(node, "dtype");
|
||||
|
||||
auto& array = model->GetOrCreateArray(node.name());
|
||||
array.data_type = dtype == DT_FLOAT
|
||||
? ArrayDataType::kFloat
|
||||
: dtype == DT_INT32
|
||||
? ArrayDataType::kInt32
|
||||
: dtype == DT_INT64 ? ArrayDataType::kInt64
|
||||
: ArrayDataType::kNone;
|
||||
if (dtype == DT_FLOAT) {
|
||||
ImportFloatArray(tensor, &array);
|
||||
} else if (dtype == DT_INT32) {
|
||||
ImportInt32Array(tensor, &array);
|
||||
} else if (dtype == DT_INT64) {
|
||||
ImportInt64Array(tensor, &array);
|
||||
} else {
|
||||
// do nothing, silently ignore the Const data. For example, there are consts
|
||||
// of string type. We just make a dummy buffer to indicate that this array
|
||||
// does not rely on external input.
|
||||
array.GetMutableBuffer<ArrayDataType::kNone>();
|
||||
switch (dtype) {
|
||||
case DT_FLOAT:
|
||||
array.data_type = ArrayDataType::kFloat;
|
||||
ImportFloatArray(tensor, &array);
|
||||
break;
|
||||
case DT_INT32:
|
||||
array.data_type = ArrayDataType::kInt32;
|
||||
ImportInt32Array(tensor, &array);
|
||||
break;
|
||||
case DT_INT64:
|
||||
array.data_type = ArrayDataType::kInt64;
|
||||
ImportInt64Array(tensor, &array);
|
||||
break;
|
||||
case DT_STRING:
|
||||
array.data_type = ArrayDataType::kString;
|
||||
ImportStringArray(tensor, &array);
|
||||
break;
|
||||
default:
|
||||
array.data_type = ArrayDataType::kNone;
|
||||
// do nothing, silently ignore the Const data.
|
||||
// We just make a dummy buffer to indicate that
|
||||
// this array does not rely on external input.
|
||||
array.GetMutableBuffer<ArrayDataType::kNone>();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1191,7 +1222,7 @@ void ConvertGatherOperator(const NodeDef& node,
|
||||
CHECK_EQ(node.op(), "Gather");
|
||||
CHECK_EQ(GetInputsCount(node, tf_import_flags), 2);
|
||||
const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
|
||||
CHECK(indices_data_type == DT_INT32);
|
||||
CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
|
||||
auto* op = new GatherOperator;
|
||||
op->inputs.push_back(node.input(0));
|
||||
op->inputs.push_back(node.input(1));
|
||||
|
@ -153,7 +153,15 @@ enum class AxesOrder {
|
||||
// because we'll be dropping the array anyway (e.g. some exotic array types
|
||||
// may be involved only in debug-only subgraphs that we may not be interested
|
||||
// in actually supporting).
|
||||
enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 };
|
||||
enum class ArrayDataType {
|
||||
kNone,
|
||||
kBool,
|
||||
kFloat,
|
||||
kUint8,
|
||||
kInt32,
|
||||
kInt64,
|
||||
kString
|
||||
};
|
||||
|
||||
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
|
||||
template <ArrayDataType A>
|
||||
@ -182,6 +190,10 @@ template <>
|
||||
struct DataTypeImpl<ArrayDataType::kInt64> {
|
||||
typedef int64 Type;
|
||||
};
|
||||
template <>
|
||||
struct DataTypeImpl<ArrayDataType::kString> {
|
||||
typedef string Type;
|
||||
};
|
||||
|
||||
template <ArrayDataType A>
|
||||
using DataType = typename DataTypeImpl<A>::Type;
|
||||
|
@ -53,6 +53,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
|
||||
return ::tflite::TensorType_INT32;
|
||||
case ArrayDataType::kUint8:
|
||||
return ::tflite::TensorType_UINT8;
|
||||
case ArrayDataType::kString:
|
||||
return ::tflite::TensorType_STRING;
|
||||
default:
|
||||
// FLOAT32 is filled for unknown data types.
|
||||
// TODO(ycling): Implement type inference in TF Lite interpreter.
|
||||
@ -66,6 +68,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
|
||||
return ArrayDataType::kFloat;
|
||||
case ::tflite::TensorType_INT32:
|
||||
return ArrayDataType::kInt32;
|
||||
case ::tflite::TensorType_STRING:
|
||||
return ArrayDataType::kString;
|
||||
case ::tflite::TensorType_UINT8:
|
||||
return ArrayDataType::kUint8;
|
||||
default:
|
||||
@ -82,6 +86,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
|
||||
return CopyBuffer<ArrayDataType::kFloat>(array, builder);
|
||||
case ArrayDataType::kInt32:
|
||||
return CopyBuffer<ArrayDataType::kInt32>(array, builder);
|
||||
case ArrayDataType::kString:
|
||||
return CopyBuffer<ArrayDataType::kString>(array, builder);
|
||||
case ArrayDataType::kUint8:
|
||||
return CopyBuffer<ArrayDataType::kUint8>(array, builder);
|
||||
default:
|
||||
@ -99,6 +105,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
|
||||
return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
|
||||
case ::tflite::TensorType_INT32:
|
||||
return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
|
||||
case ::tflite::TensorType_STRING:
|
||||
return CopyBuffer<ArrayDataType::kString>(buffer, array);
|
||||
case ::tflite::TensorType_UINT8:
|
||||
return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
|
||||
default:
|
||||
|
@ -316,6 +316,9 @@ void LogArray(int log_level, const Model& model, const string& name) {
|
||||
case ArrayDataType::kUint8:
|
||||
VLOG(log_level) << " Data type: kUint8";
|
||||
break;
|
||||
case ArrayDataType::kString:
|
||||
VLOG(log_level) << " Data type: kString";
|
||||
break;
|
||||
default:
|
||||
VLOG(log_level) << " Data type: other (numerical value: "
|
||||
<< static_cast<int>(array.data_type) << ")";
|
||||
@ -334,6 +337,9 @@ void LogArray(int log_level, const Model& model, const string& name) {
|
||||
case ArrayDataType::kUint8:
|
||||
VLOG(log_level) << " Final type: kUint8";
|
||||
break;
|
||||
case ArrayDataType::kString:
|
||||
VLOG(log_level) << " Final type: kString";
|
||||
break;
|
||||
default:
|
||||
VLOG(log_level) << " Final type: other (numerical value: "
|
||||
<< static_cast<int>(array.data_type) << ")";
|
||||
@ -1253,6 +1259,11 @@ int ElementSize(ArrayDataType data_type) {
|
||||
return 1;
|
||||
case ArrayDataType::kInt64:
|
||||
return 8;
|
||||
// Usually not critical limitation because strings are only input and/or
|
||||
// output.
|
||||
case ArrayDataType::kString:
|
||||
LOG(FATAL) << "Transient arrays with strings are not supported yet";
|
||||
return 0;
|
||||
default:
|
||||
LOG(FATAL) << "Should not get here.";
|
||||
return 0;
|
||||
|
@ -82,22 +82,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "utils_test",
|
||||
size = "small",
|
||||
srcs = ["python/saved_model/utils_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":saved_model_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/saved_model:loader",
|
||||
"//tensorflow/python/saved_model:signature_constants",
|
||||
"//tensorflow/python/saved_model:tag_constants",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "UniqueDataset"
|
||||
summary: "Creates a dataset that contains the unique elements of `input_dataset`."
|
||||
}
|
@ -35,25 +35,41 @@ bool IsAdd(const NodeDef& node) {
|
||||
|
||||
bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
|
||||
|
||||
bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
|
||||
|
||||
bool IsAnyDiv(const NodeDef& node) {
|
||||
return node.op() == "RealDiv" || node.op() == "Div" ||
|
||||
node.op() == "FloorDiv" || node.op() == "TruncateDiv";
|
||||
}
|
||||
|
||||
bool IsApproximateEqual(const NodeDef& node) {
|
||||
return node.op() == "ApproximateEqual";
|
||||
}
|
||||
|
||||
bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
|
||||
|
||||
bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
|
||||
|
||||
bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
|
||||
|
||||
bool IsBiasAdd(const NodeDef& node) {
|
||||
return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
|
||||
}
|
||||
|
||||
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
|
||||
|
||||
bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
|
||||
|
||||
bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
|
||||
|
||||
bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
|
||||
|
||||
bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
|
||||
|
||||
bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
|
||||
|
||||
bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
|
||||
|
||||
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
|
||||
|
||||
bool IsConv2DBackpropFilter(const NodeDef& node) {
|
||||
@ -92,39 +108,77 @@ bool IsEnter(const NodeDef& node) {
|
||||
return op == "Enter" || op == "RefEnter";
|
||||
}
|
||||
|
||||
bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
|
||||
|
||||
bool IsExit(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Exit" || op == "RefExit";
|
||||
}
|
||||
|
||||
bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
|
||||
|
||||
bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
|
||||
|
||||
bool IsFusedBatchNormGradV1(const NodeDef& node) {
|
||||
return node.op() == "FusedBatchNormGrad";
|
||||
bool IsFusedBatchNormGrad(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2";
|
||||
}
|
||||
|
||||
bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
|
||||
|
||||
bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
|
||||
|
||||
bool IsIdentity(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Identity" || op == "RefIdentity";
|
||||
}
|
||||
|
||||
bool IsIdentityN(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "IdentityN";
|
||||
}
|
||||
|
||||
bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
|
||||
|
||||
bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
|
||||
|
||||
bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
|
||||
|
||||
bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
|
||||
|
||||
bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
|
||||
|
||||
bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
|
||||
|
||||
bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
|
||||
|
||||
bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
|
||||
|
||||
bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
|
||||
|
||||
bool IsMatMul(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" ||
|
||||
op == "SparseMatMul";
|
||||
}
|
||||
|
||||
bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
|
||||
|
||||
bool IsMerge(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Merge" || op == "RefMerge";
|
||||
}
|
||||
|
||||
bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
|
||||
|
||||
bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
|
||||
|
||||
bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
|
||||
|
||||
bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
|
||||
|
||||
bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
|
||||
|
||||
bool IsNextIteration(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "NextIteration" || op == "RefNextIteration";
|
||||
@ -138,6 +192,12 @@ bool IsPlaceholder(const NodeDef& node) {
|
||||
op == "PlaceholderWithDefault";
|
||||
}
|
||||
|
||||
bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
|
||||
|
||||
bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
|
||||
|
||||
bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
|
||||
|
||||
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
|
||||
|
||||
bool IsReciprocalGrad(const NodeDef& node) {
|
||||
@ -209,12 +269,18 @@ bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
|
||||
|
||||
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
|
||||
|
||||
bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
|
||||
|
||||
bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
|
||||
|
||||
bool IsVariable(const NodeDef& node) {
|
||||
const auto& op = node.op();
|
||||
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
|
||||
op == "VarHandleOp" || op == "ReadVariableOp";
|
||||
}
|
||||
|
||||
bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
|
||||
|
||||
namespace {
|
||||
bool GetBoolAttr(const NodeDef& node, const string& name) {
|
||||
return node.attr().count(name) > 0 && node.attr().at(name).b();
|
||||
@ -284,5 +350,10 @@ bool IsValuePreserving(const NodeDef& node) {
|
||||
return value_preserving_ops.count(node.op()) > 0;
|
||||
}
|
||||
|
||||
bool HasOpDef(const NodeDef& node) {
|
||||
const OpDef* op_def = nullptr;
|
||||
return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -24,11 +24,18 @@ namespace grappler {
|
||||
|
||||
bool IsAdd(const NodeDef& node);
|
||||
bool IsAddN(const NodeDef& node);
|
||||
bool IsAngle(const NodeDef& node);
|
||||
bool IsAnyDiv(const NodeDef& node);
|
||||
bool IsApproximateEqual(const NodeDef& node);
|
||||
bool IsAvgPoolGrad(const NodeDef& node);
|
||||
bool IsAssert(const NodeDef& node);
|
||||
bool IsAtan2(const NodeDef& node);
|
||||
bool IsBiasAdd(const NodeDef& node);
|
||||
bool IsBiasAddGrad(const NodeDef& node);
|
||||
bool IsBitcast(const NodeDef& node);
|
||||
bool IsComplex(const NodeDef& node);
|
||||
bool IsComplexAbs(const NodeDef& node);
|
||||
bool IsConj(const NodeDef& node);
|
||||
bool IsConcatOffset(const NodeDef& node);
|
||||
bool IsConstant(const NodeDef& node);
|
||||
bool IsConv2D(const NodeDef& node);
|
||||
@ -41,18 +48,38 @@ bool IsDequeueOp(const NodeDef& node);
|
||||
bool IsDiv(const NodeDef& node);
|
||||
bool IsEluGrad(const NodeDef& node);
|
||||
bool IsEnter(const NodeDef& node);
|
||||
bool IsEqual(const NodeDef& node);
|
||||
bool IsExit(const NodeDef& node);
|
||||
bool IsFloorDiv(const NodeDef& node);
|
||||
bool IsFloorMod(const NodeDef& node);
|
||||
bool IsFusedBatchNormGradV1(const NodeDef& node);
|
||||
bool IsFusedBatchNormGrad(const NodeDef& node);
|
||||
bool IsGreater(const NodeDef& node);
|
||||
bool IsGreaterEqual(const NodeDef& node);
|
||||
bool IsIdentity(const NodeDef& node);
|
||||
bool IsIdentityN(const NodeDef& node);
|
||||
bool IsIgamma(const NodeDef& node);
|
||||
bool IsIgammac(const NodeDef& node);
|
||||
bool IsImag(const NodeDef& node);
|
||||
bool IsInvGrad(const NodeDef& node);
|
||||
bool IsLess(const NodeDef& node);
|
||||
bool IsLessEqual(const NodeDef& node);
|
||||
bool IsLogicalAnd(const NodeDef& node);
|
||||
bool IsLogicalNot(const NodeDef& node);
|
||||
bool IsLogicalOr(const NodeDef& node);
|
||||
bool IsMaximum(const NodeDef& node);
|
||||
bool IsMerge(const NodeDef& node);
|
||||
bool IsMinimum(const NodeDef& node);
|
||||
bool IsMod(const NodeDef& node);
|
||||
bool IsMul(const NodeDef& node);
|
||||
bool IsMatMul(const NodeDef& node);
|
||||
bool IsNextIteration(const NodeDef& node);
|
||||
bool IsPad(const NodeDef& node);
|
||||
bool IsNoOp(const NodeDef& node);
|
||||
bool IsNotEqual(const NodeDef& node);
|
||||
bool IsPlaceholder(const NodeDef& node);
|
||||
bool IsPolygamma(const NodeDef& node);
|
||||
bool IsPow(const NodeDef& node);
|
||||
bool IsReal(const NodeDef& node);
|
||||
bool IsRealDiv(const NodeDef& node);
|
||||
bool IsRelu6Grad(const NodeDef& node);
|
||||
bool IsReluGrad(const NodeDef& node);
|
||||
@ -80,7 +107,10 @@ bool IsSum(const NodeDef& node);
|
||||
bool IsSwitch(const NodeDef& node);
|
||||
bool IsTanhGrad(const NodeDef& node);
|
||||
bool IsTranspose(const NodeDef& node);
|
||||
bool IsTruncateDiv(const NodeDef& node);
|
||||
bool IsTruncateMod(const NodeDef& node);
|
||||
bool IsVariable(const NodeDef& node);
|
||||
bool IsZeta(const NodeDef& node);
|
||||
|
||||
// Return true if the op is an aggregation (e.g. Add, AddN).
|
||||
// Returns false if it could not be determined to be so.
|
||||
@ -102,6 +132,9 @@ bool IsInvolution(const NodeDef& node);
|
||||
// function returns true if the op commutes with all element-wise operations.
|
||||
bool IsValuePreserving(const NodeDef& node);
|
||||
|
||||
// Returns true if we can find an opdef corresponding to the op of the node.
|
||||
bool HasOpDef(const NodeDef& node);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -350,15 +351,16 @@ Status DependencyOptimizer::TransitiveReduction() {
|
||||
num_nodes);
|
||||
for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
|
||||
const NodeDef& node = optimized_graph_->node(node_idx);
|
||||
if (ModifiesFrameInfo(node)) {
|
||||
// Ignore nodes that modify frame info.
|
||||
if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
|
||||
// Ignore function nodes and nodes that modify frame info.
|
||||
continue;
|
||||
}
|
||||
for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
|
||||
const string& input = node.input(input_slot);
|
||||
const NodeDef* input_node = node_map_->GetNode(input);
|
||||
if (ModifiesFrameInfo(*input_node)) {
|
||||
// Ignore edges from nodes that modify frame info.
|
||||
if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
|
||||
// Ignore edges from nodes that modify frame info and from Merge nodes,
|
||||
// because we cannot know which of it's input paths executes.
|
||||
continue;
|
||||
}
|
||||
const int input_node_idx = node_to_idx_[input_node];
|
||||
@ -375,6 +377,14 @@ Status DependencyOptimizer::TransitiveReduction() {
|
||||
// of length > 1, we can drop that control dependency.
|
||||
int num_controls_removed = 0;
|
||||
std::vector<int> longest_distance(num_nodes);
|
||||
// Map from target_index -> set of (input_slot, source_index), representing
|
||||
// the control edges to remove. We sort them in reverse order by input slot,
|
||||
// such that when we swap them out so we don't clobber the
|
||||
// node(target).input() repeated field.
|
||||
typedef std::pair<int, int> InputSlotAndSource;
|
||||
std::unordered_map<
|
||||
int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
|
||||
control_edges_to_remove;
|
||||
for (int source = 0; source < num_nodes; ++source) {
|
||||
int highest_control_target = -1;
|
||||
for (const auto& control_output : control_outputs[source]) {
|
||||
@ -382,7 +392,7 @@ Status DependencyOptimizer::TransitiveReduction() {
|
||||
highest_control_target = control_output.first;
|
||||
}
|
||||
}
|
||||
if (highest_control_target < source) {
|
||||
if (highest_control_target <= source) {
|
||||
continue;
|
||||
}
|
||||
std::fill(longest_distance.begin() + source,
|
||||
@ -391,7 +401,10 @@ Status DependencyOptimizer::TransitiveReduction() {
|
||||
for (int input : inputs[target]) {
|
||||
// If the input node is before source in the topo order, no path
|
||||
// source -> input -> target can exits and we can skip it.
|
||||
if (input >= source) {
|
||||
// Also only extend a path from the source itself or from nodes that
|
||||
// have a path from source, indicated by longest_distance[input] > 0.
|
||||
if (input == source ||
|
||||
(input > source && longest_distance[input] > 0)) {
|
||||
// If source -> input -> target is longer than the longest
|
||||
// path so far from source -> target, update the longest_distance.
|
||||
int candidate_longest_distance = longest_distance[input] + 1;
|
||||
@ -402,25 +415,36 @@ Status DependencyOptimizer::TransitiveReduction() {
|
||||
}
|
||||
}
|
||||
|
||||
// If the longest path from the source to the target of a control dependency
|
||||
// is longer than 1, there exists an alternate path, and we can eliminate
|
||||
// the control dependency since it is redundant.
|
||||
// If the longest path from source to target of a control dependency is
|
||||
// longer than 1, there exists an alternate path, and we can eliminate the
|
||||
// redundant direct control dependency.
|
||||
for (const auto& control_output : control_outputs[source]) {
|
||||
const int target = control_output.first;
|
||||
if (longest_distance[target] > 1) {
|
||||
const int input_slot = control_output.second;
|
||||
// We modify the node inplace here. This is safe because there can
|
||||
// only be one control edge from a given source to a given target.
|
||||
const NodeDef& source_node = optimized_graph_->node(source);
|
||||
NodeDef* target_node = optimized_graph_->mutable_node(target);
|
||||
target_node->mutable_input()->SwapElements(
|
||||
input_slot, target_node->input_size() - 1);
|
||||
node_map_->RemoveOutput(source_node.name(), target_node->name());
|
||||
target_node->mutable_input()->RemoveLast();
|
||||
++num_controls_removed;
|
||||
control_edges_to_remove[target].emplace(input_slot, source);
|
||||
VLOG(1) << "Removing edge from:\n"
|
||||
<< optimized_graph_->node(source).DebugString() << "\n\nto:\n\n"
|
||||
<< optimized_graph_->node(target).DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& it : control_edges_to_remove) {
|
||||
const int target = it.first;
|
||||
NodeDef* target_node = optimized_graph_->mutable_node(target);
|
||||
for (const InputSlotAndSource& slot_and_source : it.second) {
|
||||
const int input_slot = slot_and_source.first;
|
||||
const int source = slot_and_source.second;
|
||||
const NodeDef& source_node = optimized_graph_->node(source);
|
||||
CHECK_LT(input_slot, target_node->input_size());
|
||||
target_node->mutable_input()->SwapElements(input_slot,
|
||||
target_node->input_size() - 1);
|
||||
node_map_->RemoveOutput(source_node.name(), target_node->name());
|
||||
target_node->mutable_input()->RemoveLast();
|
||||
++num_controls_removed;
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
|
||||
<< " control dependencies";
|
||||
return Status::OK();
|
||||
@ -442,36 +466,27 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
nodes_to_preserve_ = item.NodesToPreserve();
|
||||
fetch_nodes_known_ = !item.fetch.empty();
|
||||
|
||||
VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
|
||||
CleanControlInputs();
|
||||
const int num_iterations = opt_level_ == RewriterConfig::AGGRESSIVE ? 2 : 1;
|
||||
const int num_iterations = 2;
|
||||
for (int iteration = 0; iteration < num_iterations; ++iteration) {
|
||||
Status topo_sort_status;
|
||||
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
|
||||
// Prepare the graph for transitive reduction if enabled.
|
||||
topo_sort_status = TopologicalSort(optimized_graph_);
|
||||
}
|
||||
// Perform topological sort to prepare the graph for transitive reduction.
|
||||
topo_sort_status = TopologicalSort(optimized_graph_);
|
||||
|
||||
// Set up index-based graph datastructures to speed up analysis steps below.
|
||||
node_map_.reset(new NodeMap(optimized_graph_));
|
||||
BuildNodeToIdx();
|
||||
|
||||
// Remove redundant control dependencies, iteration 1.
|
||||
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
|
||||
if (topo_sort_status.ok()) {
|
||||
TF_RETURN_IF_ERROR(TransitiveReduction());
|
||||
} else {
|
||||
LOG(ERROR) << topo_sort_status.error_message();
|
||||
}
|
||||
VLOG(1) << "Graph after transitive reduction:\n"
|
||||
<< optimized_graph_->DebugString();
|
||||
if (topo_sort_status.ok()) {
|
||||
// Remove redundant control dependencies.
|
||||
TF_RETURN_IF_ERROR(TransitiveReduction());
|
||||
} else {
|
||||
LOG(ERROR) << topo_sort_status.error_message();
|
||||
}
|
||||
|
||||
// Turn nodes without non-control outputs into NoOps, prune NoOps.
|
||||
// Turn nodes with only control outputs into NoOps, prune NoOps.
|
||||
TF_RETURN_IF_ERROR(OptimizeDependencies());
|
||||
VLOG(1) << "Graph after NoOp conversion & pruning:\n"
|
||||
<< optimized_graph_->DebugString();
|
||||
}
|
||||
VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
|
||||
DependencyOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
@ -228,6 +228,7 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) {
|
||||
|
||||
// The optimization should be disabled to prevent increasing the number of
|
||||
// nodes crossing device boundaries.
|
||||
TF_CHECK_OK(TopologicalSort(&item.graph));
|
||||
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
|
||||
}
|
||||
|
||||
@ -282,7 +283,7 @@ TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
item.fetch.push_back("id2");
|
||||
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
|
||||
DependencyOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
@ -60,7 +60,9 @@ std::set<string> GetOpsFormatSupported() {
|
||||
"DepthwiseConv2dNativeBackpropInput",
|
||||
"DepthwiseConv2dNativeBackpropFilter",
|
||||
"FusedBatchNorm",
|
||||
"FusedBatchNormV2",
|
||||
"FusedBatchNormGrad",
|
||||
"FusedBatchNormGradV2",
|
||||
"FusedConv2DBiasActivation",
|
||||
"MaxPool",
|
||||
"MaxPoolGrad",
|
||||
@ -75,52 +77,77 @@ std::set<string> GetOpsFormatAgnostic() {
|
||||
std::set<string> ops_format_agnostic = {"Abs",
|
||||
"Add",
|
||||
"AddN",
|
||||
"AddV2",
|
||||
"Acos",
|
||||
"Acosh",
|
||||
"Angle",
|
||||
"ApproximateEqual",
|
||||
"Asin",
|
||||
"Asinh",
|
||||
"Atan",
|
||||
"Atan2",
|
||||
"Atanh",
|
||||
"Bitcast",
|
||||
"Cast",
|
||||
"Ceil",
|
||||
"CheckNumerics",
|
||||
"Cos",
|
||||
"Cosh",
|
||||
"Complex",
|
||||
"ComplexAbs",
|
||||
"Concat",
|
||||
"ConcatV2",
|
||||
"Conj",
|
||||
"Cos",
|
||||
"Cosh",
|
||||
"Digamma",
|
||||
"Div",
|
||||
"Elu",
|
||||
"EluGrad",
|
||||
"Equal",
|
||||
"Erf",
|
||||
"Erfc",
|
||||
"Exp",
|
||||
"Expm1",
|
||||
"Floor",
|
||||
"FloorDiv",
|
||||
"FloorMod",
|
||||
"Greater",
|
||||
"GreaterEqual",
|
||||
"GuaranteeConst",
|
||||
"Identity",
|
||||
"IdentityN",
|
||||
"Igamma",
|
||||
"Igammac",
|
||||
"Imag",
|
||||
"Inv",
|
||||
"InvGrad",
|
||||
"IsFinite",
|
||||
"IsInf",
|
||||
"IsNan",
|
||||
"Less",
|
||||
"LessEqual",
|
||||
"Lgamma",
|
||||
"Log",
|
||||
"LogicalAnd",
|
||||
"LogicalNot",
|
||||
"LogicalOr",
|
||||
"Log1p",
|
||||
"Maximum",
|
||||
"Merge",
|
||||
"Minimum",
|
||||
"Mod",
|
||||
"Mul",
|
||||
"Neg",
|
||||
"NotEqual",
|
||||
"OnesLike",
|
||||
"Pad",
|
||||
"PreventGradient",
|
||||
"Polygamma",
|
||||
"Pow",
|
||||
"Real",
|
||||
"RealDiv",
|
||||
"Reciprocal",
|
||||
"ReciprocalGrad",
|
||||
"RefIdentity",
|
||||
"Relu",
|
||||
"Relu6",
|
||||
"Relu6Grad",
|
||||
@ -141,7 +168,8 @@ std::set<string> GetOpsFormatAgnostic() {
|
||||
"SoftplusGrad",
|
||||
"Split",
|
||||
"Switch",
|
||||
"RefIdentity",
|
||||
"TruncateDiv",
|
||||
"TruncateMod",
|
||||
"RefMerge",
|
||||
"RefSwitch",
|
||||
"Round",
|
||||
@ -157,7 +185,8 @@ std::set<string> GetOpsFormatAgnostic() {
|
||||
"Tan",
|
||||
"Tanh",
|
||||
"TanhGrad",
|
||||
"ZerosLike"};
|
||||
"ZerosLike",
|
||||
"Zeta"};
|
||||
return ops_format_agnostic;
|
||||
}
|
||||
|
||||
@ -212,6 +241,28 @@ bool IsUnaryGrad(const NodeDef& node) {
|
||||
return is_unary_grad;
|
||||
}
|
||||
|
||||
bool IsComparisonOp(const NodeDef& node) {
|
||||
bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
|
||||
IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
|
||||
IsLessEqual(node) || IsNotEqual(node);
|
||||
return is_compare;
|
||||
}
|
||||
|
||||
bool IsLogicalOp(const NodeDef& node) {
|
||||
return IsLogicalAnd(node) || IsLogicalNot(node) || IsLogicalOr(node);
|
||||
}
|
||||
|
||||
bool IsBinaryOp(const NodeDef& node) {
|
||||
bool is_binary =
|
||||
IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
|
||||
IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
|
||||
IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
|
||||
IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
|
||||
IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
|
||||
IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
|
||||
return is_binary;
|
||||
}
|
||||
|
||||
class GraphProcessor {
|
||||
public:
|
||||
GraphProcessor(const VirtualPlacer& virtual_placer,
|
||||
@ -409,17 +460,19 @@ class NodeProcessor : public GraphProcessor {
|
||||
|
||||
virtual void UpdateAttrShape() {
|
||||
if (node_->attr().find("_output_shapes") != node_->attr().end()) {
|
||||
auto shape = node_->mutable_attr()
|
||||
->at("_output_shapes")
|
||||
.mutable_list()
|
||||
->mutable_shape(0);
|
||||
if (shape->dim_size() == 4) {
|
||||
int64 h = shape->dim(1).size();
|
||||
int64 w = shape->dim(2).size();
|
||||
int64 c = shape->dim(3).size();
|
||||
shape->mutable_dim(1)->set_size(c);
|
||||
shape->mutable_dim(2)->set_size(h);
|
||||
shape->mutable_dim(3)->set_size(w);
|
||||
for (const auto& pos : GetOutputPos()) {
|
||||
auto shape = node_->mutable_attr()
|
||||
->at("_output_shapes")
|
||||
.mutable_list()
|
||||
->mutable_shape(pos);
|
||||
if (shape->dim_size() == 4) {
|
||||
int64 h = shape->dim(1).size();
|
||||
int64 w = shape->dim(2).size();
|
||||
int64 c = shape->dim(3).size();
|
||||
shape->mutable_dim(1)->set_size(c);
|
||||
shape->mutable_dim(2)->set_size(h);
|
||||
shape->mutable_dim(3)->set_size(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -603,13 +656,15 @@ class NodeProcessor : public GraphProcessor {
|
||||
added_node_name = AddPrefixToNodeName(added_node_base_name,
|
||||
kTransposeNCHWToNHWC, "-");
|
||||
DataType dtype;
|
||||
if (op == "Imag" || op == "Real" || op == "Angle" ||
|
||||
op == "Conj" || op == "ComplexAbs") {
|
||||
if (IsAngle(*node_) || IsComplex(*node_) ||
|
||||
IsComplexAbs(*node_) || IsImag(*node_) || IsReal(*node_)) {
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "Tout"));
|
||||
dtype = node_->attr().at("Tout").type();
|
||||
} else if (op == "Bitcast") {
|
||||
} else if (IsBitcast(*node_)) {
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "type"));
|
||||
dtype = node_->attr().at("type").type();
|
||||
} else if (IsLogicalOp(*node_) || IsComparisonOp(*node_)) {
|
||||
dtype = DT_BOOL;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
|
||||
dtype = node_->attr().at("T").type();
|
||||
@ -617,7 +672,8 @@ class NodeProcessor : public GraphProcessor {
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
|
||||
AddNodeTranspose(
|
||||
added_node_name, input, const_name, dtype,
|
||||
node_->attr().at("_output_shapes").list().shape(0), false);
|
||||
node_->attr().at("_output_shapes").list().shape(input_port),
|
||||
false);
|
||||
} else if (op == "DataFormatVecPermute") {
|
||||
added_node_name = AddPrefixToNodeName(added_node_base_name,
|
||||
kVecPermuteNCHWToNHWC, "-");
|
||||
@ -1002,11 +1058,10 @@ class AgnosticNodeProcessor : public NodeProcessor {
|
||||
if (IsConcatV1(node)) {
|
||||
return {1};
|
||||
}
|
||||
if (IsAdd(node) || IsMul(node) || IsRealDiv(node) ||
|
||||
IsSquaredDifference(node) || IsSub(node)) {
|
||||
if (IsBinaryOp(node) || IsUnaryGrad(node)) {
|
||||
return {0, 1};
|
||||
}
|
||||
if (IsShapeN(node)) {
|
||||
if (IsShapeN(node) || IsIdentityN(node)) {
|
||||
std::vector<int> pos;
|
||||
for (int i = 0; i < node.input_size(); i++) {
|
||||
pos.push_back(i);
|
||||
@ -1207,6 +1262,40 @@ class ConcatProcessor : public AgnosticNodeProcessor {
|
||||
int axis_node_pos_;
|
||||
};
|
||||
|
||||
class IdentityNProcessor : public AgnosticNodeProcessor {
|
||||
public:
|
||||
explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
|
||||
: AgnosticNodeProcessor(opt_cxt) {}
|
||||
|
||||
protected:
|
||||
bool ShouldProcess() const override {
|
||||
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
|
||||
IsOnGPU();
|
||||
}
|
||||
|
||||
std::vector<int> GetInputPos() const override {
|
||||
std::vector<int> input_pos;
|
||||
for (int i = 0; i < node_->input_size(); i++) {
|
||||
auto input = node_map_->GetNode(node_->input(i));
|
||||
int port;
|
||||
ParseNodeName(node_->input(i), &port);
|
||||
if (IsPortDimsFour(*input, port) &&
|
||||
(IsNodeAfterNCHWToNHWC(*input) || IsNodeNCHWToNHWC(input->name()))) {
|
||||
input_pos.push_back(i);
|
||||
}
|
||||
}
|
||||
return input_pos;
|
||||
}
|
||||
|
||||
std::set<int> GetOutputPos() const override {
|
||||
std::set<int> output_pos{};
|
||||
for (const auto& input_pos : GetInputPos()) {
|
||||
output_pos.insert(input_pos);
|
||||
}
|
||||
return output_pos;
|
||||
}
|
||||
};
|
||||
|
||||
class MergeProcessor : public AgnosticNodeProcessor {
|
||||
public:
|
||||
explicit MergeProcessor(const OptimizeContext& opt_cxt)
|
||||
@ -1371,12 +1460,15 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
|
||||
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
|
||||
|
||||
bool IsInputConvertible() const {
|
||||
int input_port;
|
||||
auto input = node_map_->GetNode(node_->input(0));
|
||||
ParseNodeName(node_->input(0), &input_port);
|
||||
if (IsNodeNCHWToNHWC(input->name())) {
|
||||
input = node_map_->GetNode(input->input(0));
|
||||
ParseNodeName(input->input(0), &input_port);
|
||||
}
|
||||
if (input->attr().find("_output_shapes") != input->attr().end()) {
|
||||
auto shape = input->attr().at("_output_shapes").list().shape(0);
|
||||
auto shape = input->attr().at("_output_shapes").list().shape(input_port);
|
||||
if (shape.dim_size() != 4) {
|
||||
return false;
|
||||
}
|
||||
@ -1529,7 +1621,7 @@ class DataLayoutOptimizer : GraphProcessor {
|
||||
new Conv2DBackpropFilterProcessor(opt_cxt, true));
|
||||
} else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
|
||||
node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
|
||||
} else if (IsFusedBatchNormGradV1(*node)) {
|
||||
} else if (IsFusedBatchNormGrad(*node)) {
|
||||
node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
|
||||
} else if (IsMaxPoolGradV1(*node)) {
|
||||
node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
|
||||
@ -1557,11 +1649,12 @@ class DataLayoutOptimizer : GraphProcessor {
|
||||
std::unique_ptr<NodeProcessor> node_processor;
|
||||
if (IsAddN(*node)) {
|
||||
node_processor.reset(new AddNProcessor(opt_cxt));
|
||||
} else if (IsAdd(*node) || IsMul(*node) || IsRealDiv(*node) ||
|
||||
IsSquaredDifference(*node) || IsSub(*node)) {
|
||||
} else if (IsBinaryOp(*node)) {
|
||||
node_processor.reset(new BinaryOpProcessor(opt_cxt));
|
||||
} else if (IsConcat(*node)) {
|
||||
node_processor.reset(new ConcatProcessor(opt_cxt));
|
||||
} else if (IsIdentityN(*node)) {
|
||||
node_processor.reset(new IdentityNProcessor(opt_cxt));
|
||||
} else if (IsMerge(*node)) {
|
||||
node_processor.reset(new MergeProcessor(opt_cxt));
|
||||
} else if (IsPad(*node)) {
|
||||
|
@ -1066,6 +1066,48 @@ TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
|
||||
"LayoutOptimizerTransposeNCHWToNHWC-Conv2D-0-1");
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, Complex) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
auto conv = SimpleConv2D(&s, 4, 2, "VALID");
|
||||
auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
|
||||
auto i = ops::Identity(s.WithOpName("i"), comp);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
LayoutOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto merge_node = node_map.GetNode("complex");
|
||||
EXPECT_EQ(merge_node->input(0), "Conv2D");
|
||||
EXPECT_EQ(merge_node->input(1), "Conv2D");
|
||||
auto trans =
|
||||
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-complex-0-0");
|
||||
EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
|
||||
}
|
||||
|
||||
TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
auto conv = SimpleConv2D(&s, 4, 2, "VALID");
|
||||
auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
|
||||
auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
|
||||
auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
LayoutOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
|
||||
NodeMap node_map(&output);
|
||||
auto i = node_map.GetNode("identity_n");
|
||||
EXPECT_EQ(i->input(0), "vector");
|
||||
EXPECT_EQ(i->input(1), "Conv2D");
|
||||
auto trans =
|
||||
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
|
||||
EXPECT_EQ(trans->input(0), "identity_n:1");
|
||||
auto add_node = node_map.GetNode("add");
|
||||
EXPECT_EQ(add_node->input(0), "identity_n");
|
||||
EXPECT_EQ(add_node->input(1),
|
||||
"LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
|
||||
}
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -314,6 +314,11 @@ Status CudaSolver::forward_input_or_allocate_scoped_tensor(
|
||||
// are sometimes inaccurate, e.g., are missing 'const' on pointers
|
||||
// to immutable arguments, while the actual headers have them as expected.
|
||||
// Check the actual declarations in the cusolver_api.h header file.
|
||||
//
|
||||
// NOTE: The cuSolver functions called below appear not to be threadsafe.
|
||||
// so we put a global lock around the calls. Since these functions only put a
|
||||
// kernel on the shared stream, it is not a big performance hit.
|
||||
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
|
||||
//=============================================================================
|
||||
|
||||
template <typename Scalar, typename SolverFnT>
|
||||
@ -324,6 +329,7 @@ static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
|
||||
const Scalar* A, int lda,
|
||||
const Scalar* beta, /* host or device pointer */
|
||||
const Scalar* B, int ldb, Scalar* C, int ldc) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
|
||||
reinterpret_cast<const CudaScalar*>(alpha),
|
||||
@ -355,6 +361,7 @@ static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
|
||||
cusolverDnHandle_t cusolver_dn_handle,
|
||||
cublasFillMode_t uplo, int n, Scalar* A, int lda,
|
||||
int* dev_lapack_info) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Get amount of workspace memory required. */
|
||||
int lwork;
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(
|
||||
@ -387,6 +394,7 @@ static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
|
||||
cusolverDnHandle_t cusolver_dn_handle, int m,
|
||||
int n, Scalar* A, int lda, int* dev_pivots,
|
||||
int* dev_lapack_info) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Get amount of workspace memory required. */
|
||||
int lwork;
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(
|
||||
@ -419,9 +427,6 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
|
||||
cublasOperation_t trans, int n, int nrhs,
|
||||
const Scalar* A, int lda, const int* pivots,
|
||||
Scalar* B, int ldb, int* dev_lapack_info) {
|
||||
// Note: The cuSolver functions called here appear not to be threadsafe.
|
||||
// so we put a global lock around it. Since this function only puts a
|
||||
// kernel on the stream, it is not a big performance hit.
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Launch the solver kernel. */
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
|
||||
@ -449,6 +454,7 @@ static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
|
||||
cusolverDnHandle_t cusolver_dn_handle, int m,
|
||||
int n, Scalar* A, int lda, Scalar* tau,
|
||||
int* dev_lapack_info) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Get amount of workspace memory required. */
|
||||
int lwork;
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(
|
||||
@ -483,6 +489,7 @@ static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
|
||||
int m, int n, int k, const Scalar* dev_a,
|
||||
int lda, const Scalar* dev_tau, Scalar* dev_c,
|
||||
int ldc, int* dev_lapack_info) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Get amount of workspace memory required. */
|
||||
int lwork;
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(
|
||||
@ -526,6 +533,7 @@ static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
|
||||
cusolverDnHandle_t cusolver_dn_handle, int m,
|
||||
int n, int k, Scalar* dev_a, int lda,
|
||||
const Scalar* dev_tau, int* dev_lapack_info) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Get amount of workspace memory required. */
|
||||
int lwork;
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
|
||||
@ -606,17 +614,13 @@ static inline Status GesvdImpl(
|
||||
OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
|
||||
signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
|
||||
Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Get amount of workspace memory required. */
|
||||
int lwork;
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
|
||||
/* Allocate device memory for workspace. */
|
||||
auto dev_workspace =
|
||||
cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
|
||||
// Note: The cuSolver functions called here appear not to be threadsafe.
|
||||
// so we put a global lock around it. Since this function only puts a
|
||||
// kernel on the stream, it is not a big performance hit.
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
/* Launch the solver kernel. */
|
||||
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
|
||||
CUDAComplex(A), lda, S, CUDAComplex(U),
|
||||
ldu, CUDAComplex(VT), ldvt,
|
||||
@ -655,6 +659,7 @@ static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
|
||||
int lda, int* dev_pivots,
|
||||
DeviceLapackInfo* dev_lapack_info,
|
||||
int batch_size) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
ScratchSpace<uint8> dev_a_dev_ptrs =
|
||||
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
|
||||
@ -689,6 +694,7 @@ static inline Status GetrsBatchedImpl(
|
||||
const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
|
||||
const Scalar* const host_b_dev_ptrs[], int ldb,
|
||||
DeviceLapackInfo* dev_lapack_info, int batch_size) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
ScratchSpace<uint8> dev_a_dev_ptrs =
|
||||
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
|
||||
@ -734,6 +740,7 @@ static inline Status GetriBatchedImpl(
|
||||
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
|
||||
int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
|
||||
int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
ScratchSpace<uint8> dev_a_dev_ptrs =
|
||||
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
|
||||
@ -776,6 +783,7 @@ static inline Status MatInvBatchedImpl(
|
||||
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
|
||||
int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
|
||||
DeviceLapackInfo* dev_lapack_info, int batch_size) {
|
||||
mutex_lock lock(handle_map_mutex);
|
||||
using CudaScalar = typename CUDAComplexT<Scalar>::type;
|
||||
ScratchSpace<uint8> dev_a_dev_ptrs =
|
||||
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
|
||||
|
@ -493,6 +493,18 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "unique_dataset_op",
|
||||
srcs = ["unique_dataset_op.cc"],
|
||||
deps = [
|
||||
":dataset",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "dataset_ops",
|
||||
deps = [
|
||||
@ -526,6 +538,7 @@ tf_kernel_library(
|
||||
":take_dataset_op",
|
||||
":tensor_dataset_op",
|
||||
":tensor_slice_dataset_op",
|
||||
":unique_dataset_op",
|
||||
":zip_dataset_op",
|
||||
],
|
||||
)
|
||||
|
@ -55,10 +55,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
*output = nullptr;
|
||||
|
||||
#define HANDLE_TYPE(T) \
|
||||
case DataTypeToEnum<T>::value: { \
|
||||
*output = new Dataset<T>(batch_size, row_shape, input); \
|
||||
break; \
|
||||
#define HANDLE_TYPE(T) \
|
||||
case DataTypeToEnum<T>::value: { \
|
||||
*output = new Dataset<T>(ctx, batch_size, row_shape, input); \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (input->output_dtypes()[0]) {
|
||||
@ -75,11 +75,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
private:
|
||||
// TODO(mrry): Push the templated code down to the raw copying routine.
|
||||
template <class T>
|
||||
class Dataset : public DatasetBase {
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
Dataset(int64 batch_size, const PartialTensorShape& row_shape,
|
||||
const DatasetBase* input)
|
||||
: batch_size_(batch_size), row_shape_(row_shape), input_(input) {
|
||||
Dataset(OpKernelContext* ctx, int64 batch_size,
|
||||
const PartialTensorShape& row_shape, const DatasetBase* input)
|
||||
: GraphDatasetBase(ctx),
|
||||
batch_size_(batch_size),
|
||||
row_shape_(row_shape),
|
||||
input_(input) {
|
||||
input_->Ref();
|
||||
|
||||
output_shapes_.reserve(3);
|
||||
@ -112,6 +115,25 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
")::Dataset");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_node;
|
||||
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
|
||||
Node* batch_size_node;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
|
||||
Node* row_shape_node;
|
||||
std::vector<int64> row_shape;
|
||||
row_shape.reserve(
|
||||
row_shape_.dims()); // not an unknown rank PartialTensorShape
|
||||
for (int i = 0; i < row_shape_.dims(); i++)
|
||||
row_shape.emplace_back(row_shape_.dim_size(i));
|
||||
TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_node, batch_size_node, row_shape_node}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset<T>> {
|
||||
public:
|
||||
@ -242,6 +264,20 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(Iterator::SaveParent(writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(Iterator::RestoreParent(ctx, reader, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
|
@ -64,20 +64,23 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::move(other_arguments),
|
||||
&captured_func));
|
||||
|
||||
*output =
|
||||
new Dataset(input, std::move(initial_state), std::move(captured_func),
|
||||
state_types_, output_types_, output_shapes_);
|
||||
*output = new Dataset(ctx, input, func_, std::move(initial_state),
|
||||
std::move(captured_func), state_types_, output_types_,
|
||||
output_shapes_);
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
Dataset(const DatasetBase* input, std::vector<Tensor> initial_state,
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
const NameAttrList& func, std::vector<Tensor> initial_state,
|
||||
std::unique_ptr<CapturedFunction> captured_func,
|
||||
const DataTypeVector& state_types,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: input_(input),
|
||||
: GraphDatasetBase(ctx),
|
||||
input_(input),
|
||||
func_(func),
|
||||
initial_state_(std::move(initial_state)),
|
||||
captured_func_(std::move(captured_func)),
|
||||
state_types_(state_types),
|
||||
@ -103,6 +106,45 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
string DebugString() override { return "ScanDatasetOp::Dataset"; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
|
||||
Node* input_node;
|
||||
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
|
||||
std::vector<Node*> initial_state_nodes;
|
||||
initial_state_nodes.reserve(initial_state_.size());
|
||||
for (const Tensor& t : initial_state_) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
initial_state_nodes.emplace_back(node);
|
||||
}
|
||||
std::vector<Node*> other_arguments;
|
||||
other_arguments.reserve(captured_func_->captured_inputs().size());
|
||||
DataTypeVector other_arguments_types;
|
||||
other_arguments_types.reserve(captured_func_->captured_inputs().size());
|
||||
for (const Tensor& t : captured_func_->captured_inputs()) {
|
||||
Node* node;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
other_arguments.emplace_back(node);
|
||||
other_arguments_types.emplace_back(t.dtype());
|
||||
}
|
||||
AttrValue f;
|
||||
b->BuildAttrValue(func_, &f);
|
||||
AttrValue state_types;
|
||||
b->BuildAttrValue(state_types_, &state_types);
|
||||
AttrValue other_arguments_types_attr;
|
||||
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {{0, input_node}},
|
||||
{{1, initial_state_nodes}, {2, other_arguments}},
|
||||
{{"f", f},
|
||||
{"Tstate", state_types},
|
||||
{"Targuments", other_arguments_types_attr}},
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
@ -185,6 +227,38 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
return s;
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
if (!state_.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("state_size"), state_.size()));
|
||||
for (int idx = 0; idx < state_.size(); idx++) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(strings::StrCat("state[", idx, "]")), state_[idx]));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
if (reader->Contains(full_name("state_size"))) {
|
||||
int64 size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
reader->ReadScalar(full_name("state_size"), &size));
|
||||
state_.resize(size);
|
||||
for (int idx = 0; idx < size; idx++) {
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
full_name(strings::StrCat("state[", idx, "]")), &state_[idx]));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
@ -192,6 +266,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
const NameAttrList func_;
|
||||
const std::vector<Tensor> initial_state_;
|
||||
const std::unique_ptr<CapturedFunction> captured_func_;
|
||||
const DataTypeVector state_types_;
|
||||
|
219
tensorflow/core/kernels/data/unique_dataset_op.cc
Normal file
219
tensorflow/core/kernels/data/unique_dataset_op.cc
Normal file
@ -0,0 +1,219 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// See documentation in ../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
|
||||
class UniqueDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit UniqueDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {}
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
OP_REQUIRES(ctx, input->output_dtypes().size() == 1,
|
||||
errors::InvalidArgument("UniqueDataset only supports "
|
||||
"inputs with a single component."));
|
||||
|
||||
DataType input_dtype = input->output_dtypes()[0];
|
||||
OP_REQUIRES(ctx,
|
||||
input_dtype == DT_INT32 || input_dtype == DT_INT64 ||
|
||||
input_dtype == DT_STRING,
|
||||
errors::InvalidArgument(
|
||||
"UniqueDataset only supports inputs with a single "
|
||||
"`tf.int32`, `tf.int64`, or `tf.string` component."));
|
||||
|
||||
*output = new Dataset(ctx, input);
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public GraphDatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, const DatasetBase* input)
|
||||
: GraphDatasetBase(ctx), input_(input) {
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { input_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIterator(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(
|
||||
new Iterator({this, strings::StrCat(prefix, "::Unique")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return input_->output_dtypes();
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return input_->output_shapes();
|
||||
}
|
||||
|
||||
string DebugString() override {
|
||||
return strings::StrCat("UniqueDatasetOp::Dataset");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const typename Iterator::Params& params)
|
||||
: DatasetIterator<Dataset>(params),
|
||||
input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
bool saw_new_value;
|
||||
do {
|
||||
saw_new_value = false;
|
||||
out_tensors->clear();
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
|
||||
if (*end_of_sequence) {
|
||||
break;
|
||||
}
|
||||
DCHECK_EQ(1, out_tensors->size());
|
||||
saw_new_value = unique_elements_.insert((*out_tensors)[0]).second;
|
||||
} while (!saw_new_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name("unique_elements_size"), unique_elements_.size()));
|
||||
size_t i = 0;
|
||||
for (const Tensor& t : unique_elements_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteTensor(
|
||||
full_name(strings::StrCat("unique_elements[", i++, "]")), t));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RestoreInternal(OpKernelContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!reader->Contains(full_name("input_impl_empty"))) {
|
||||
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
|
||||
} else {
|
||||
input_impl_.reset();
|
||||
}
|
||||
int64 num_unique_elements;
|
||||
unique_elements_.clear();
|
||||
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
|
||||
&num_unique_elements));
|
||||
for (int64 i = 0; i < num_unique_elements; ++i) {
|
||||
Tensor unique_element;
|
||||
TF_RETURN_IF_ERROR(reader->ReadTensor(
|
||||
full_name(strings::StrCat("unique_elements[", i, "]")),
|
||||
&unique_element));
|
||||
auto insert_result = unique_elements_.insert(unique_element);
|
||||
if (!insert_result.second) {
|
||||
return errors::InvalidArgument(
|
||||
"Checkpoint contained two unique elements with the same "
|
||||
"value.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
struct TensorHash {
|
||||
size_t operator()(const Tensor& t) const {
|
||||
if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
|
||||
return Hash64(t.tensor_data().data(), t.tensor_data().size());
|
||||
} else {
|
||||
DCHECK_EQ(DT_STRING, t.dtype());
|
||||
auto flat_t = t.flat<string>();
|
||||
uint64 hash = 0;
|
||||
for (int64 i = 0; i < t.NumElements(); ++i) {
|
||||
hash = Hash64Combine(hash, Hash64(flat_t(i)));
|
||||
}
|
||||
return static_cast<size_t>(hash);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TensorKeyEqual {
|
||||
bool operator()(const Tensor& lhs, const Tensor& rhs) const {
|
||||
if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
|
||||
return false;
|
||||
}
|
||||
switch (lhs.dtype()) {
|
||||
#define HANDLE_TYPE(T) \
|
||||
case T: \
|
||||
do { \
|
||||
auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
|
||||
auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
|
||||
for (int64 i = 0; i < lhs.NumElements(); ++i) { \
|
||||
if (lhs_flat(i) != rhs_flat(i)) { \
|
||||
return false; \
|
||||
} \
|
||||
} \
|
||||
return true; \
|
||||
} while (0)
|
||||
|
||||
HANDLE_TYPE(DT_INT32);
|
||||
HANDLE_TYPE(DT_INT64);
|
||||
HANDLE_TYPE(DT_STRING);
|
||||
default:
|
||||
LOG(FATAL) << "UniqueDataset unhandled data type: "
|
||||
<< DataTypeString(lhs.dtype());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
|
||||
GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
|
||||
UniqueDatasetOp);
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorflow
|
@ -558,6 +558,16 @@ filename: A path on the filesystem where we should cache the dataset. Note: this
|
||||
will be a directory.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("UniqueDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Creates a dataset that contains the unique elements of `input_dataset`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("TextLineDataset")
|
||||
.Input("filenames: string")
|
||||
.Input("compression_type: string")
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "tensorflow/core/platform/cloud/curl_http_request.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -327,6 +329,57 @@ Status CurlHttpRequest::SetResultBuffer(std::vector<char>* out_buffer) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CurlHttpRequest::SetResultBufferDirect(char* buffer, size_t size) {
|
||||
CHECK(buffer != nullptr);
|
||||
TF_RETURN_IF_ERROR(CheckInitialized());
|
||||
TF_RETURN_IF_ERROR(CheckNotSent());
|
||||
|
||||
direct_response_ = DirectResponseState{buffer, size, 0};
|
||||
|
||||
libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA,
|
||||
reinterpret_cast<void*>(this));
|
||||
libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION,
|
||||
&CurlHttpRequest::WriteCallbackDirect);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
size_t CurlHttpRequest::WriteCallbackDirect(const void* ptr, size_t size,
|
||||
size_t nmemb, void* userdata) {
|
||||
CHECK(ptr != nullptr);
|
||||
auto that = reinterpret_cast<CurlHttpRequest*>(userdata);
|
||||
DirectResponseState* state = &that->direct_response_;
|
||||
CHECK(state->buffer_ != nullptr);
|
||||
CHECK(state->bytes_transferred_ <= state->buffer_size_);
|
||||
|
||||
size_t curl_bytes_received = size * nmemb;
|
||||
size_t user_buffer_bytes_available =
|
||||
state->buffer_size_ - state->bytes_transferred_;
|
||||
|
||||
// The HTTP server may send a response body that is longer than what we
|
||||
// expected. We must not use CHECK() for this situation, because that would
|
||||
// imply a code bug (in this client code) where none exists; the violation of
|
||||
// expectations would have been caused by the server, not the client. So we
|
||||
// report a log warning, if an HTTP server is misbehaving.
|
||||
if (curl_bytes_received > user_buffer_bytes_available) {
|
||||
LOG(WARNING) << "The HTTP response body that we received is longer than we "
|
||||
"requested or expected. "
|
||||
<< "Total bytes requested: " << state->buffer_size_
|
||||
<< " Bytes received (so far) in HTTP response body: "
|
||||
<< (state->bytes_transferred_ + curl_bytes_received);
|
||||
}
|
||||
|
||||
size_t bytes_to_copy =
|
||||
std::min<size_t>(curl_bytes_received, user_buffer_bytes_available);
|
||||
memcpy(&state->buffer_[state->bytes_transferred_], ptr, bytes_to_copy);
|
||||
state->bytes_transferred_ += bytes_to_copy;
|
||||
return bytes_to_copy;
|
||||
}
|
||||
|
||||
size_t CurlHttpRequest::GetResultBufferDirectBytesTransferred() {
|
||||
CHECK(direct_response_.buffer_ != nullptr);
|
||||
return direct_response_.bytes_transferred_;
|
||||
}
|
||||
|
||||
Status CurlHttpRequest::SetTimeouts(uint32 connection, uint32 inactivity,
|
||||
uint32 total) {
|
||||
TF_RETURN_IF_ERROR(CheckInitialized());
|
||||
|
@ -103,6 +103,26 @@ class CurlHttpRequest : public HttpRequest {
|
||||
/// read. Existing content of the vector will be cleared.
|
||||
Status SetResultBuffer(std::vector<char>* out_buffer) override;
|
||||
|
||||
/// \brief Specifies the buffer for receiving the response body, when the
|
||||
/// caller knows the maximum size of the response body.
|
||||
///
|
||||
/// This method allows the caller to receive the response body without an
|
||||
/// additional intermediate buffer allocation and copy. This method should
|
||||
/// be called before calling Send(). After Send() has succeeded, the caller
|
||||
/// should use the GetResultBufferDirectBytesTransferred() method in order
|
||||
/// to learn how many bytes were transferred.
|
||||
///
|
||||
/// Using this method is mutually exclusive with using SetResultBuffer().
|
||||
Status SetResultBufferDirect(char* buffer, size_t size) override;
|
||||
|
||||
/// \brief Returns the number of bytes (of the response body) that were
|
||||
/// transferred, when using the SetResultBufferDirect() method. The returned
|
||||
/// value will always be less than or equal to the 'size' parameter that
|
||||
/// was passed to SetResultBufferDirect(). If the actual HTTP response body
|
||||
/// was greater than 'size' bytes, then this transfer method will only copy
|
||||
/// the first 'size' bytes, and the rest will be ignored.
|
||||
size_t GetResultBufferDirectBytesTransferred() override;
|
||||
|
||||
/// \brief Returns the response headers of a completed request.
|
||||
///
|
||||
/// If the header is not found, returns an empty string.
|
||||
@ -127,6 +147,10 @@ class CurlHttpRequest : public HttpRequest {
|
||||
/// A write callback in the form which can be accepted by libcurl.
|
||||
static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
|
||||
void* userdata);
|
||||
|
||||
/// Processes response body content received when using SetResultBufferDirect.
|
||||
static size_t WriteCallbackDirect(const void* ptr, size_t size, size_t nmemb,
|
||||
void* userdata);
|
||||
/// A read callback in the form which can be accepted by libcurl.
|
||||
static size_t ReadCallback(void* ptr, size_t size, size_t nmemb,
|
||||
FILE* userdata);
|
||||
@ -150,6 +174,14 @@ class CurlHttpRequest : public HttpRequest {
|
||||
size_t post_body_read_ = 0;
|
||||
|
||||
std::vector<char>* response_buffer_ = nullptr;
|
||||
|
||||
struct DirectResponseState {
|
||||
char* buffer_;
|
||||
size_t buffer_size_;
|
||||
size_t bytes_transferred_;
|
||||
};
|
||||
DirectResponseState direct_response_ = {};
|
||||
|
||||
CURL* curl_ = nullptr;
|
||||
curl_slist* curl_headers_ = nullptr;
|
||||
curl_slist* resolve_list_ = nullptr;
|
||||
|
@ -288,6 +288,39 @@ TEST(CurlHttpRequestTest, GetRequest) {
|
||||
EXPECT_EQ(200, http_request.GetResponseCode());
|
||||
}
|
||||
|
||||
TEST(CurlHttpRequestTest, GetRequest_Direct) {
|
||||
FakeLibCurl libcurl("get response", 200);
|
||||
CurlHttpRequest http_request(&libcurl);
|
||||
TF_EXPECT_OK(http_request.Init());
|
||||
|
||||
std::vector<char> scratch(100, 0);
|
||||
|
||||
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
|
||||
TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer"));
|
||||
TF_EXPECT_OK(http_request.SetRange(100, 199));
|
||||
TF_EXPECT_OK(
|
||||
http_request.SetResultBufferDirect(scratch.data(), scratch.capacity()));
|
||||
TF_EXPECT_OK(http_request.Send());
|
||||
|
||||
string expected_response = "get response";
|
||||
size_t response_bytes_transferred =
|
||||
http_request.GetResultBufferDirectBytesTransferred();
|
||||
EXPECT_EQ(response_bytes_transferred, expected_response.size());
|
||||
EXPECT_EQ(
|
||||
"get response",
|
||||
string(scratch.begin(), scratch.begin() + response_bytes_transferred));
|
||||
|
||||
// Check interactions with libcurl.
|
||||
EXPECT_TRUE(libcurl.is_initialized_);
|
||||
EXPECT_EQ("http://www.testuri.com", libcurl.url_);
|
||||
EXPECT_EQ("100-199", libcurl.range_);
|
||||
EXPECT_EQ("", libcurl.custom_request_);
|
||||
EXPECT_EQ(1, libcurl.headers_->size());
|
||||
EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl.headers_)[0]);
|
||||
EXPECT_FALSE(libcurl.is_post_);
|
||||
EXPECT_EQ(200, http_request.GetResponseCode());
|
||||
}
|
||||
|
||||
TEST(CurlHttpRequestTest, GetRequest_Empty) {
|
||||
FakeLibCurl libcurl("", 200);
|
||||
CurlHttpRequest http_request(&libcurl);
|
||||
|
@ -123,8 +123,12 @@ Status FileBlockCache::MaybeFetch(const Key& key,
|
||||
case FetchState::CREATED:
|
||||
block->state = FetchState::FETCHING;
|
||||
block->mu.unlock(); // Release the lock while making the API call.
|
||||
status.Update(
|
||||
block_fetcher_(key.first, key.second, block_size_, &block->data));
|
||||
block->data.clear();
|
||||
block->data.resize(block_size_, 0);
|
||||
size_t bytes_transferred;
|
||||
status.Update(block_fetcher_(key.first, key.second, block_size_,
|
||||
block->data.data(), &bytes_transferred));
|
||||
block->data.resize(bytes_transferred, 0);
|
||||
block->mu.lock(); // Reacquire the lock immediately afterwards
|
||||
if (status.ok()) {
|
||||
downloaded_block = true;
|
||||
@ -150,15 +154,15 @@ Status FileBlockCache::MaybeFetch(const Key& key,
|
||||
}
|
||||
|
||||
Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
out->clear();
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
*bytes_transferred = 0;
|
||||
if (n == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (block_size_ == 0 || max_bytes_ == 0) {
|
||||
// The cache is effectively disabled, so we pass the read through to the
|
||||
// fetcher without breaking it up into blocks.
|
||||
return block_fetcher_(filename, offset, n, out);
|
||||
return block_fetcher_(filename, offset, n, buffer, bytes_transferred);
|
||||
}
|
||||
// Calculate the block-aligned start and end of the read.
|
||||
size_t start = block_size_ * (offset / block_size_);
|
||||
@ -166,6 +170,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
|
||||
if (finish < offset + n) {
|
||||
finish += block_size_;
|
||||
}
|
||||
size_t total_bytes_transferred = 0;
|
||||
// Now iterate through the blocks, reading them one at a time.
|
||||
for (size_t pos = start; pos < finish; pos += block_size_) {
|
||||
Key key = std::make_pair(filename, pos);
|
||||
@ -181,6 +186,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
|
||||
// The requested offset is at or beyond the end of the file. This can
|
||||
// happen if `offset` is not block-aligned, and the read returns the last
|
||||
// block in the file, which does not extend all the way out to `offset`.
|
||||
*bytes_transferred = total_bytes_transferred;
|
||||
return errors::OutOfRange("EOF at offset ", offset, " in file ", filename,
|
||||
" at position ", pos, "with data size ",
|
||||
data.size());
|
||||
@ -196,13 +202,16 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
|
||||
end -= (pos + data.size()) - (offset + n);
|
||||
}
|
||||
if (begin < end) {
|
||||
out->insert(out->end(), begin, end);
|
||||
size_t bytes_to_copy = end - begin;
|
||||
memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy);
|
||||
total_bytes_transferred += bytes_to_copy;
|
||||
}
|
||||
if (data.size() < block_size_) {
|
||||
// The block was a partial block and thus signals EOF at its upper bound.
|
||||
break;
|
||||
}
|
||||
}
|
||||
*bytes_transferred = total_bytes_transferred;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -43,8 +43,9 @@ class FileBlockCache {
|
||||
/// cache is constructed. The returned Status should be OK as long as the
|
||||
/// read from the remote filesystem succeeded (similar to the semantics of the
|
||||
/// read(2) system call).
|
||||
typedef std::function<Status(const string&, size_t, size_t,
|
||||
std::vector<char>*)>
|
||||
typedef std::function<Status(const string& filename, size_t offset,
|
||||
size_t buffer_size, char* buffer,
|
||||
size_t* bytes_transferred)>
|
||||
BlockFetcher;
|
||||
|
||||
FileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness,
|
||||
@ -83,8 +84,8 @@ class FileBlockCache {
|
||||
/// placed in `out`.
|
||||
/// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
|
||||
/// in `out`).
|
||||
Status Read(const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out);
|
||||
Status Read(const string& filename, size_t offset, size_t n, char* buffer,
|
||||
size_t* bytes_transferred);
|
||||
|
||||
/// Remove all cached blocks for `filename`.
|
||||
void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_);
|
||||
|
@ -25,6 +25,18 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status ReadCache(FileBlockCache* cache, const string& filename, size_t offset,
|
||||
size_t n, std::vector<char>* out) {
|
||||
out->clear();
|
||||
out->resize(n, 0);
|
||||
size_t bytes_transferred = 0;
|
||||
Status status =
|
||||
cache->Read(filename, offset, n, out->data(), &bytes_transferred);
|
||||
EXPECT_LE(bytes_transferred, n);
|
||||
out->resize(bytes_transferred, n);
|
||||
return status;
|
||||
}
|
||||
|
||||
TEST(FileBlockCacheTest, PassThrough) {
|
||||
const string want_filename = "foo/bar";
|
||||
const size_t want_offset = 42;
|
||||
@ -32,12 +44,13 @@ TEST(FileBlockCacheTest, PassThrough) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls, want_filename, want_offset, want_n](
|
||||
const string& got_filename, size_t got_offset,
|
||||
size_t got_n, std::vector<char>* out) {
|
||||
size_t got_n, char* buffer, size_t* bytes_transferred) {
|
||||
EXPECT_EQ(got_filename, want_filename);
|
||||
EXPECT_EQ(got_offset, want_offset);
|
||||
EXPECT_EQ(got_n, want_n);
|
||||
calls++;
|
||||
out->resize(got_n, 'x');
|
||||
memset(buffer, 'x', got_n);
|
||||
*bytes_transferred = got_n;
|
||||
return Status::OK();
|
||||
};
|
||||
// If block_size, max_bytes, or both are zero, the cache is a pass-through.
|
||||
@ -45,11 +58,11 @@ TEST(FileBlockCacheTest, PassThrough) {
|
||||
FileBlockCache cache2(0, 1, 0, fetcher);
|
||||
FileBlockCache cache3(0, 0, 0, fetcher);
|
||||
std::vector<char> out;
|
||||
TF_EXPECT_OK(cache1.Read(want_filename, want_offset, want_n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out));
|
||||
EXPECT_EQ(calls, 1);
|
||||
TF_EXPECT_OK(cache2.Read(want_filename, want_offset, want_n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out));
|
||||
EXPECT_EQ(calls, 2);
|
||||
TF_EXPECT_OK(cache3.Read(want_filename, want_offset, want_n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out));
|
||||
EXPECT_EQ(calls, 3);
|
||||
}
|
||||
|
||||
@ -63,13 +76,13 @@ TEST(FileBlockCacheTest, BlockAlignment) {
|
||||
}
|
||||
// The fetcher just fetches slices of the buffer.
|
||||
auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
if (offset < buf.size()) {
|
||||
if (offset + n > buf.size()) {
|
||||
out->insert(out->end(), buf.begin() + offset, buf.end());
|
||||
} else {
|
||||
out->insert(out->end(), buf.begin() + offset, buf.begin() + offset + n);
|
||||
}
|
||||
size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
|
||||
memcpy(buffer, buf.data() + offset, bytes_to_copy);
|
||||
*bytes_transferred = bytes_to_copy;
|
||||
} else {
|
||||
*bytes_transferred = 0;
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
@ -80,7 +93,7 @@ TEST(FileBlockCacheTest, BlockAlignment) {
|
||||
for (size_t offset = 0; offset < 10; offset++) {
|
||||
for (size_t n = block_size - 2; n <= block_size + 2; n++) {
|
||||
std::vector<char> got;
|
||||
TF_EXPECT_OK(cache.Read("", offset, n, &got));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got));
|
||||
// Verify the size of the read.
|
||||
if (offset + n <= size) {
|
||||
// Expect a full read.
|
||||
@ -108,24 +121,27 @@ TEST(FileBlockCacheTest, CacheHits) {
|
||||
const size_t block_size = 16;
|
||||
std::set<size_t> calls;
|
||||
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
|
||||
size_t n, std::vector<char>* out) {
|
||||
size_t n, char* buffer,
|
||||
size_t* bytes_transferred) {
|
||||
EXPECT_EQ(n, block_size);
|
||||
EXPECT_EQ(offset % block_size, 0);
|
||||
EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
|
||||
calls.insert(offset);
|
||||
out->resize(n, 'x');
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return Status::OK();
|
||||
};
|
||||
const uint32 block_count = 256;
|
||||
FileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
|
||||
std::vector<char> out;
|
||||
out.resize(block_count, 0);
|
||||
// The cache has space for `block_count` blocks. The loop with i = 0 should
|
||||
// fill the cache, and the loop with i = 1 should be all cache hits. The
|
||||
// fetcher checks that it is called once and only once for each offset (to
|
||||
// fetch the corresponding block).
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < block_count; j++) {
|
||||
TF_EXPECT_OK(cache.Read("", block_size * j, block_size, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -138,36 +154,39 @@ TEST(FileBlockCacheTest, OutOfRange) {
|
||||
bool second_block = false;
|
||||
auto fetcher = [block_size, file_size, &first_block, &second_block](
|
||||
const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
EXPECT_EQ(n, block_size);
|
||||
EXPECT_EQ(offset % block_size, 0);
|
||||
size_t bytes_to_copy = 0;
|
||||
if (offset == 0) {
|
||||
// The first block (16 bytes) of the file.
|
||||
out->resize(n, 'x');
|
||||
memset(buffer, 'x', n);
|
||||
bytes_to_copy = n;
|
||||
first_block = true;
|
||||
} else if (offset == block_size) {
|
||||
// The second block (8 bytes) of the file.
|
||||
out->resize(file_size - block_size, 'x');
|
||||
bytes_to_copy = file_size - block_size;
|
||||
memset(buffer, 'x', bytes_to_copy);
|
||||
second_block = true;
|
||||
}
|
||||
*bytes_transferred = bytes_to_copy;
|
||||
return Status::OK();
|
||||
};
|
||||
FileBlockCache cache(block_size, block_size, 0, fetcher);
|
||||
std::vector<char> out;
|
||||
// Reading the first 16 bytes should be fine.
|
||||
TF_EXPECT_OK(cache.Read("", 0, block_size, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out));
|
||||
EXPECT_TRUE(first_block);
|
||||
EXPECT_EQ(out.size(), block_size);
|
||||
// Reading at offset file_size + 4 will read the second block (since the read
|
||||
// at file_size + 4 = 28 will be aligned to an offset of 16) but will return
|
||||
// OutOfRange because the offset is past the end of the 24-byte file.
|
||||
Status status = cache.Read("", file_size + 4, 4, &out);
|
||||
Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
|
||||
EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
|
||||
EXPECT_TRUE(second_block);
|
||||
EXPECT_EQ(out.size(), 0);
|
||||
// Reading the second full block will return 8 bytes, from a cache hit.
|
||||
second_block = false;
|
||||
TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
|
||||
EXPECT_FALSE(second_block);
|
||||
EXPECT_EQ(out.size(), file_size - block_size);
|
||||
}
|
||||
@ -178,20 +197,22 @@ TEST(FileBlockCacheTest, Inconsistent) {
|
||||
const size_t block_size = 16;
|
||||
// This fetcher returns OK but only fills in one byte for any offset.
|
||||
auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
EXPECT_EQ(n, block_size);
|
||||
EXPECT_EQ(offset % block_size, 0);
|
||||
out->resize(1, 'x');
|
||||
EXPECT_GE(n, 1);
|
||||
memset(buffer, 'x', 1);
|
||||
*bytes_transferred = 1;
|
||||
return Status::OK();
|
||||
};
|
||||
FileBlockCache cache(block_size, 2 * block_size, 0, fetcher);
|
||||
std::vector<char> out;
|
||||
// Read the second block; this should yield an OK status and a single byte.
|
||||
TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
|
||||
EXPECT_EQ(out.size(), 1);
|
||||
// Now read the first block; this should yield an INTERNAL error because we
|
||||
// had already cached a partial block at a later position.
|
||||
Status status = cache.Read("", 0, block_size, &out);
|
||||
Status status = ReadCache(&cache, "", 0, block_size, &out);
|
||||
EXPECT_EQ(status.code(), error::INTERNAL);
|
||||
}
|
||||
|
||||
@ -199,14 +220,16 @@ TEST(FileBlockCacheTest, LRU) {
|
||||
const size_t block_size = 16;
|
||||
std::list<size_t> calls;
|
||||
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
|
||||
size_t n, std::vector<char>* out) {
|
||||
size_t n, char* buffer,
|
||||
size_t* bytes_transferred) {
|
||||
EXPECT_EQ(n, block_size);
|
||||
EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
|
||||
if (!calls.empty()) {
|
||||
EXPECT_EQ(offset, calls.front());
|
||||
calls.pop_front();
|
||||
}
|
||||
out->resize(n, 'x');
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return Status::OK();
|
||||
};
|
||||
const uint32 block_count = 2;
|
||||
@ -216,38 +239,39 @@ TEST(FileBlockCacheTest, LRU) {
|
||||
// fetcher calls that the cache makes.
|
||||
calls.push_back(0);
|
||||
// Cache miss - drains an element from `calls`.
|
||||
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
|
||||
// Cache hit - does not drain an element from `calls`.
|
||||
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
|
||||
calls.push_back(block_size);
|
||||
// Cache miss followed by cache hit.
|
||||
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
|
||||
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
|
||||
calls.push_back(2 * block_size);
|
||||
// Cache miss followed by cache hit. Causes eviction of LRU element.
|
||||
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
|
||||
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
|
||||
// LRU element was at offset 0. Cache miss.
|
||||
calls.push_back(0);
|
||||
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
|
||||
// Element at 2 * block_size is still in cache, and this read should update
|
||||
// its position in the LRU list so it doesn't get evicted by the next read.
|
||||
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
|
||||
// Element at block_size was evicted. Reading this element will also cause
|
||||
// the LRU element (at 0) to be evicted.
|
||||
calls.push_back(block_size);
|
||||
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
|
||||
// Element at 0 was evicted again.
|
||||
calls.push_back(0);
|
||||
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
|
||||
}
|
||||
|
||||
TEST(FileBlockCacheTest, MaxStaleness) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
calls++;
|
||||
out->resize(n, 'x');
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return Status::OK();
|
||||
};
|
||||
std::vector<char> out;
|
||||
@ -256,14 +280,14 @@ TEST(FileBlockCacheTest, MaxStaleness) {
|
||||
// expected.
|
||||
FileBlockCache cache1(8, 16, 2 /* max staleness */, fetcher, env.get());
|
||||
// Execute the first read to load the block.
|
||||
TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
|
||||
EXPECT_EQ(calls, 1);
|
||||
// Now advance the clock one second at a time and redo the read. The call
|
||||
// count should advance every 3 seconds (i.e. every time the staleness is
|
||||
// greater than 2).
|
||||
for (int i = 1; i <= 10; i++) {
|
||||
env->SetNowSeconds(i + 1);
|
||||
TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
|
||||
EXPECT_EQ(calls, 1 + i / 3);
|
||||
}
|
||||
// Now create a cache with max staleness of 0, and verify that it also works
|
||||
@ -272,27 +296,27 @@ TEST(FileBlockCacheTest, MaxStaleness) {
|
||||
env->SetNowSeconds(0);
|
||||
FileBlockCache cache2(8, 16, 0 /* max staleness */, fetcher, env.get());
|
||||
// Execute the first read to load the block.
|
||||
TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
|
||||
EXPECT_EQ(calls, 1);
|
||||
// Advance the clock by a huge amount and verify that the cached block is
|
||||
// used to satisfy the read.
|
||||
env->SetNowSeconds(365 * 24 * 60 * 60); // ~1 year, just for fun.
|
||||
TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
|
||||
EXPECT_EQ(calls, 1);
|
||||
}
|
||||
|
||||
TEST(FileBlockCacheTest, RemoveFile) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
calls++;
|
||||
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
|
||||
if (offset > 0) {
|
||||
// The first block is lower case and all subsequent blocks are upper case.
|
||||
c = toupper(c);
|
||||
}
|
||||
out->clear();
|
||||
out->resize(n, c);
|
||||
memset(buffer, c, n);
|
||||
*bytes_transferred = n;
|
||||
return Status::OK();
|
||||
};
|
||||
// This cache has space for 4 blocks; we'll read from two files.
|
||||
@ -304,41 +328,41 @@ TEST(FileBlockCacheTest, RemoveFile) {
|
||||
std::vector<char> A(n, 'A');
|
||||
std::vector<char> B(n, 'B');
|
||||
// Fill the cache.
|
||||
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
|
||||
EXPECT_EQ(out, a);
|
||||
EXPECT_EQ(calls, 1);
|
||||
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
|
||||
EXPECT_EQ(out, A);
|
||||
EXPECT_EQ(calls, 2);
|
||||
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
|
||||
EXPECT_EQ(out, b);
|
||||
EXPECT_EQ(calls, 3);
|
||||
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
|
||||
EXPECT_EQ(out, B);
|
||||
EXPECT_EQ(calls, 4);
|
||||
// All four blocks should be in the cache now.
|
||||
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
|
||||
EXPECT_EQ(out, a);
|
||||
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
|
||||
EXPECT_EQ(out, A);
|
||||
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
|
||||
EXPECT_EQ(out, b);
|
||||
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
|
||||
EXPECT_EQ(out, B);
|
||||
EXPECT_EQ(calls, 4);
|
||||
// Remove the blocks from "a".
|
||||
cache.RemoveFile("a");
|
||||
// Both blocks from "b" should still be there.
|
||||
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
|
||||
EXPECT_EQ(out, b);
|
||||
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
|
||||
EXPECT_EQ(out, B);
|
||||
EXPECT_EQ(calls, 4);
|
||||
// The blocks from "a" should not be there.
|
||||
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
|
||||
EXPECT_EQ(out, a);
|
||||
EXPECT_EQ(calls, 5);
|
||||
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
|
||||
EXPECT_EQ(out, A);
|
||||
EXPECT_EQ(calls, 6);
|
||||
}
|
||||
@ -346,10 +370,10 @@ TEST(FileBlockCacheTest, RemoveFile) {
|
||||
TEST(FileBlockCacheTest, Prune) {
|
||||
int calls = 0;
|
||||
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
calls++;
|
||||
out->clear();
|
||||
out->resize(n, 'x');
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return Status::OK();
|
||||
};
|
||||
std::vector<char> out;
|
||||
@ -360,20 +384,20 @@ TEST(FileBlockCacheTest, Prune) {
|
||||
FileBlockCache cache(8, 32, 1 /* max staleness */, fetcher, env.get());
|
||||
// Read three blocks into the cache, and advance the timestamp by one second
|
||||
// with each read. Start with a block of "a" at the current timestamp `now`.
|
||||
TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
|
||||
// Now load a block of a different file "b" at timestamp `now` + 1
|
||||
env->SetNowSeconds(now + 1);
|
||||
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
|
||||
// Now load a different block of file "a" at timestamp `now` + 1. When the
|
||||
// first block of "a" expires, this block should also be removed because it
|
||||
// also belongs to file "a".
|
||||
TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
|
||||
// Ensure that all blocks are in the cache (i.e. reads are cache hits).
|
||||
EXPECT_EQ(cache.CacheSize(), 24);
|
||||
EXPECT_EQ(calls, 3);
|
||||
TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
|
||||
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
|
||||
TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
|
||||
EXPECT_EQ(calls, 3);
|
||||
// Advance the fake timestamp so that "a" becomes stale via its first block.
|
||||
env->SetNowSeconds(now + 2);
|
||||
@ -389,7 +413,7 @@ TEST(FileBlockCacheTest, Prune) {
|
||||
// There should be one block left in the cache, and it should be the first
|
||||
// block of "b".
|
||||
EXPECT_EQ(cache.CacheSize(), 8);
|
||||
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
|
||||
EXPECT_EQ(calls, 3);
|
||||
// Advance the fake time to `now` + 3, at which point "b" becomes stale.
|
||||
env->SetNowSeconds(now + 3);
|
||||
@ -409,14 +433,14 @@ TEST(FileBlockCacheTest, ParallelReads) {
|
||||
const int callers = 4;
|
||||
BlockingCounter counter(callers);
|
||||
auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
counter.DecrementCount();
|
||||
if (!counter.WaitFor(std::chrono::seconds(10))) {
|
||||
// This avoids having the test time out, which is harder to debug.
|
||||
return errors::FailedPrecondition("desired concurrency not reached");
|
||||
}
|
||||
out->clear();
|
||||
out->resize(n, 'x');
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
return Status::OK();
|
||||
};
|
||||
const int block_size = 8;
|
||||
@ -426,7 +450,8 @@ TEST(FileBlockCacheTest, ParallelReads) {
|
||||
threads.emplace_back(
|
||||
Env::Default()->StartThread({}, "caller", [&cache, i, block_size]() {
|
||||
std::vector<char> out;
|
||||
TF_EXPECT_OK(cache.Read("a", i * block_size, block_size, &out));
|
||||
TF_EXPECT_OK(
|
||||
ReadCache(&cache, "a", i * block_size, block_size, &out));
|
||||
std::vector<char> x(block_size, 'x');
|
||||
EXPECT_EQ(out, x);
|
||||
}));
|
||||
@ -443,11 +468,12 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
|
||||
Notification notification;
|
||||
auto fetcher = [&num_requests, ¬ification, block_size](
|
||||
const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
EXPECT_EQ(n, block_size);
|
||||
EXPECT_EQ(offset, 0);
|
||||
num_requests++;
|
||||
out->resize(n, 'x');
|
||||
memset(buffer, 'x', n);
|
||||
*bytes_transferred = n;
|
||||
notification.Notify();
|
||||
// Wait for other thread to issue read.
|
||||
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
|
||||
@ -458,17 +484,16 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
|
||||
std::unique_ptr<Thread> concurrent(
|
||||
Env::Default()->StartThread({}, "concurrent", [&cache, block_size] {
|
||||
std::vector<char> out;
|
||||
TF_EXPECT_OK(cache.Read("", 0, block_size / 2, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out));
|
||||
EXPECT_EQ(out.size(), block_size / 2);
|
||||
}));
|
||||
EXPECT_TRUE(WaitForNotificationWithTimeout(¬ification, 10000))
|
||||
<< "Timeout waiting for concurrent thread to start.";
|
||||
std::vector<char> out;
|
||||
TF_EXPECT_OK(cache.Read("", block_size / 2, block_size / 2, &out));
|
||||
TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out));
|
||||
EXPECT_EQ(out.size(), block_size / 2);
|
||||
|
||||
EXPECT_EQ(1, num_requests);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -58,6 +58,10 @@ class TestHttpRequest : public HttpRequest {
|
||||
Status SetResultBuffer(std::vector<char>* out_buffer) override {
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetResultBufferDirect(char* buffer, size_t size) override {
|
||||
return Status::OK();
|
||||
}
|
||||
size_t GetResultBufferDirectBytesTransferred() override { return 0; }
|
||||
|
||||
string GetResponseHeader(const string& name) const override { return ""; }
|
||||
uint64 GetResponseCode() const override { return 0; }
|
||||
|
@ -285,11 +285,11 @@ class GcsRandomAccessFile : public RandomAccessFile {
|
||||
Status Read(uint64 offset, size_t n, StringPiece* result,
|
||||
char* scratch) const override {
|
||||
*result = StringPiece();
|
||||
std::vector<char> out;
|
||||
TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, &out));
|
||||
std::memcpy(scratch, out.data(), std::min(out.size(), n));
|
||||
*result = StringPiece(scratch, std::min(out.size(), n));
|
||||
if (result->size() < n) {
|
||||
size_t bytes_transferred;
|
||||
TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch,
|
||||
&bytes_transferred));
|
||||
*result = StringPiece(scratch, bytes_transferred);
|
||||
if (bytes_transferred < n) {
|
||||
// This is not an error per se. The RandomAccessFile interface expects
|
||||
// that Read returns OutOfRange if fewer bytes were read than requested.
|
||||
return errors::OutOfRange("EOF reached, ", result->size(),
|
||||
@ -721,15 +721,17 @@ std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache(
|
||||
std::unique_ptr<FileBlockCache> file_block_cache(
|
||||
new FileBlockCache(block_size, max_bytes, max_staleness,
|
||||
[this](const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out) {
|
||||
return LoadBufferFromGCS(filename, offset, n, out);
|
||||
char* buffer, size_t* bytes_transferred) {
|
||||
return LoadBufferFromGCS(filename, offset, n, buffer,
|
||||
bytes_transferred);
|
||||
}));
|
||||
return file_block_cache;
|
||||
}
|
||||
|
||||
// A helper function to actually read the data from GCS.
|
||||
Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
|
||||
size_t n, std::vector<char>* out) {
|
||||
size_t n, char* buffer,
|
||||
size_t* bytes_transferred) {
|
||||
string bucket, object;
|
||||
TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object));
|
||||
|
||||
@ -739,21 +741,23 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
|
||||
request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket,
|
||||
"/", request->EscapeString(object))));
|
||||
TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1));
|
||||
TF_RETURN_IF_ERROR(request->SetResultBuffer(out));
|
||||
TF_RETURN_IF_ERROR(request->SetResultBufferDirect(buffer, n));
|
||||
TF_RETURN_IF_ERROR(
|
||||
request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read));
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
|
||||
bucket, "/", object);
|
||||
|
||||
size_t bytes_read = request->GetResultBufferDirectBytesTransferred();
|
||||
*bytes_transferred = bytes_read;
|
||||
VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
|
||||
<< offset << " of size: " << out->size();
|
||||
<< offset << " of size: " << bytes_read;
|
||||
|
||||
if (out->size() < block_size()) {
|
||||
if (bytes_read < block_size()) {
|
||||
// Check stat cache to see if we encountered an interrupted read.
|
||||
FileStatistics stat;
|
||||
if (stat_cache_->Lookup(filename, &stat)) {
|
||||
if (offset + out->size() < stat.length) {
|
||||
if (offset + bytes_read < stat.length) {
|
||||
return errors::Internal(strings::Printf(
|
||||
"File contents are inconsistent for file: %s @ %lu.",
|
||||
filename.c_str(), offset));
|
||||
|
@ -177,7 +177,7 @@ class GcsFileSystem : public FileSystem {
|
||||
|
||||
/// Loads file contents from GCS for a given filename, offset, and length.
|
||||
Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n,
|
||||
std::vector<char>* out);
|
||||
char* buffer, size_t* bytes_transferred);
|
||||
|
||||
std::unique_ptr<AuthProvider> auth_provider_;
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
|
||||
|
@ -101,6 +101,20 @@ class HttpRequest {
|
||||
/// read. Existing content of the vector will be cleared.
|
||||
virtual Status SetResultBuffer(std::vector<char>* out_buffer) = 0;
|
||||
|
||||
/// \brief Specifies the buffer for receiving the response body.
|
||||
///
|
||||
/// This method should be used when a caller knows the upper bound of the
|
||||
/// size of the response data. The caller provides a pre-allocated buffer
|
||||
/// and its size. After the Send() method is called, the
|
||||
/// GetResultBufferDirectBytesTransferred() method may be used to learn to the
|
||||
/// number of bytes that were transferred using this method.
|
||||
virtual Status SetResultBufferDirect(char* buffer, size_t size) = 0;
|
||||
|
||||
/// \brief Returns the number of bytes transferred, when using
|
||||
/// SetResultBufferDirect(). This method may only be used when using
|
||||
/// SetResultBufferDirect().
|
||||
virtual size_t GetResultBufferDirectBytesTransferred() = 0;
|
||||
|
||||
/// \brief Returns the response headers of a completed request.
|
||||
///
|
||||
/// If the header is not found, returns an empty string.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -130,12 +131,25 @@ class FakeHttpRequest : public CurlHttpRequest {
|
||||
buffer_ = buffer;
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetResultBufferDirect(char* buffer, size_t size) override {
|
||||
direct_result_buffer_ = buffer;
|
||||
direct_result_buffer_size_ = size;
|
||||
return Status::OK();
|
||||
}
|
||||
size_t GetResultBufferDirectBytesTransferred() override {
|
||||
return direct_result_bytes_transferred_;
|
||||
}
|
||||
Status Send() override {
|
||||
EXPECT_EQ(expected_request_, actual_request())
|
||||
<< "Unexpected HTTP request.";
|
||||
if (buffer_) {
|
||||
buffer_->insert(buffer_->begin(), response_.c_str(),
|
||||
response_.c_str() + response_.size());
|
||||
buffer_->insert(buffer_->begin(), response_.data(),
|
||||
response_.data() + response_.size());
|
||||
} else if (direct_result_buffer_ != nullptr) {
|
||||
size_t bytes_to_copy =
|
||||
std::min<size_t>(direct_result_buffer_size_, response_.size());
|
||||
memcpy(direct_result_buffer_, response_.data(), bytes_to_copy);
|
||||
direct_result_bytes_transferred_ += bytes_to_copy;
|
||||
}
|
||||
return response_status_;
|
||||
}
|
||||
@ -178,6 +192,9 @@ class FakeHttpRequest : public CurlHttpRequest {
|
||||
}
|
||||
|
||||
std::vector<char>* buffer_ = nullptr;
|
||||
char* direct_result_buffer_ = nullptr;
|
||||
size_t direct_result_buffer_size_ = 0;
|
||||
size_t direct_result_bytes_transferred_ = 0;
|
||||
string expected_request_;
|
||||
string actual_uri_;
|
||||
string actual_request_;
|
||||
|
@ -136,15 +136,19 @@ void Env::GetLocalTempDirectories(std::vector<string>* list) {
|
||||
// Directories, in order of preference. If we find a dir that
|
||||
// exists, we stop adding other less-preferred dirs
|
||||
const char* candidates[] = {
|
||||
// Non-null only during unittest/regtest
|
||||
getenv("TEST_TMPDIR"),
|
||||
// Non-null only during unittest/regtest
|
||||
getenv("TEST_TMPDIR"),
|
||||
|
||||
// Explicitly-supplied temp dirs
|
||||
getenv("TMPDIR"),
|
||||
getenv("TMP"),
|
||||
// Explicitly-supplied temp dirs
|
||||
getenv("TMPDIR"),
|
||||
getenv("TMP"),
|
||||
|
||||
// If all else fails
|
||||
"/tmp",
|
||||
#if defined(__ANDROID__)
|
||||
"/data/local/tmp",
|
||||
#endif
|
||||
|
||||
// If all else fails
|
||||
"/tmp",
|
||||
};
|
||||
|
||||
for (const char* d : candidates) {
|
||||
|
127
tensorflow/go/genop/api_def_map.go
Normal file
127
tensorflow/go/genop/api_def_map.go
Normal file
@ -0,0 +1,127 @@
|
||||
/*
|
||||
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package internal
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
|
||||
)
|
||||
|
||||
// Encapsulates a collection of API definitions.
|
||||
//
|
||||
// apiDefMap represents a map from operation name to corresponding
|
||||
// ApiDef proto (see
|
||||
// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto
|
||||
// for ApiDef proto definition).
|
||||
type apiDefMap struct {
|
||||
c *C.TF_ApiDefMap
|
||||
}
|
||||
|
||||
// Creates and returns a new apiDefMap instance.
|
||||
//
|
||||
// oplist is and OpList proto instance (see
|
||||
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto
|
||||
// for OpList proto definition).
|
||||
|
||||
func newAPIDefMap(oplist *pb.OpList) (*apiDefMap, error) {
|
||||
// Create a buffer containing the serialized OpList.
|
||||
opdefSerialized, err := proto.Marshal(oplist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not serialize OpDef for %s", oplist.String())
|
||||
}
|
||||
data := C.CBytes(opdefSerialized)
|
||||
defer C.free(data)
|
||||
|
||||
opbuf := C.TF_NewBuffer()
|
||||
defer C.TF_DeleteBuffer(opbuf)
|
||||
opbuf.data = data
|
||||
opbuf.length = C.size_t(len(opdefSerialized))
|
||||
|
||||
// Create ApiDefMap.
|
||||
status := C.TF_NewStatus()
|
||||
defer C.TF_DeleteStatus(status)
|
||||
capimap := C.TF_NewApiDefMap(opbuf, status)
|
||||
if C.TF_GetCode(status) != C.TF_OK {
|
||||
return nil, errors.New(C.GoString(C.TF_Message(status)))
|
||||
}
|
||||
apimap := &apiDefMap{capimap}
|
||||
runtime.SetFinalizer(
|
||||
apimap,
|
||||
func(a *apiDefMap) {
|
||||
C.TF_DeleteApiDefMap(a.c)
|
||||
})
|
||||
return apimap, nil
|
||||
}
|
||||
|
||||
// Updates apiDefMap with the overrides specified in `data`.
|
||||
//
|
||||
// data - ApiDef text proto.
|
||||
func (m *apiDefMap) Put(data string) error {
|
||||
cdata := C.CString(data)
|
||||
defer C.free(unsafe.Pointer(cdata))
|
||||
status := C.TF_NewStatus()
|
||||
defer C.TF_DeleteStatus(status)
|
||||
C.TF_ApiDefMapPut(m.c, cdata, C.size_t(len(data)), status)
|
||||
if C.TF_GetCode(status) != C.TF_OK {
|
||||
return errors.New(C.GoString(C.TF_Message(status)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns ApiDef proto instance for the TensorFlow operation
|
||||
// named `opname`.
|
||||
func (m *apiDefMap) Get(opname string) (*pb.ApiDef, error) {
|
||||
cname := C.CString(opname)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
status := C.TF_NewStatus()
|
||||
defer C.TF_DeleteStatus(status)
|
||||
apidefBuf := C.TF_ApiDefMapGet(
|
||||
m.c, cname, C.size_t(len(opname)), status)
|
||||
defer C.TF_DeleteBuffer(apidefBuf)
|
||||
if C.TF_GetCode(status) != C.TF_OK {
|
||||
return nil, errors.New(C.GoString(C.TF_Message(status)))
|
||||
}
|
||||
if apidefBuf == nil {
|
||||
return nil, fmt.Errorf("could not find ApiDef for %s", opname)
|
||||
}
|
||||
|
||||
var (
|
||||
apidef = new(pb.ApiDef)
|
||||
size = int(apidefBuf.length)
|
||||
// A []byte backed by C memory.
|
||||
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
|
||||
data = (*[1 << 30]byte)(unsafe.Pointer(apidefBuf.data))[:size:size]
|
||||
err = proto.Unmarshal(data, apidef)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return apidef, nil
|
||||
}
|
@ -29,12 +29,18 @@ limitations under the License.
|
||||
// encountered.
|
||||
package internal
|
||||
|
||||
// #include "tensorflow/c/c_api.h"
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
"text/template"
|
||||
@ -47,15 +53,23 @@ import (
|
||||
// GenerateFunctionsForRegisteredOps writes a Go source code file to w
|
||||
// containing functions for each TensorFlow operation registered in the address
|
||||
// space of the calling process.
|
||||
func GenerateFunctionsForRegisteredOps(w io.Writer) error {
|
||||
ops, err := registeredOps()
|
||||
// apidefDirs should be a contain of directories containing api_def_*.pbtxt
|
||||
// files to load.
|
||||
func GenerateFunctionsForRegisteredOps(
|
||||
w io.Writer, apidefDirs []string) error {
|
||||
ops, apimap, err := registeredOps()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return generateFunctionsForOps(w, ops)
|
||||
for _, dir := range apidefDirs {
|
||||
if err = updateAPIDefs(apimap, dir); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return generateFunctionsForOps(w, ops, apimap)
|
||||
}
|
||||
|
||||
func registeredOps() (*pb.OpList, error) {
|
||||
func registeredOps() (*pb.OpList, *apiDefMap, error) {
|
||||
buf := C.TF_GetAllOpList()
|
||||
defer C.TF_DeleteBuffer(buf)
|
||||
var (
|
||||
@ -66,10 +80,31 @@ func registeredOps() (*pb.OpList, error) {
|
||||
data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
|
||||
err = proto.Unmarshal(data, list)
|
||||
)
|
||||
return list, err
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
apimap, err := newAPIDefMap(list)
|
||||
return list, apimap, err
|
||||
}
|
||||
|
||||
func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error {
|
||||
func updateAPIDefs(m *apiDefMap, dir string) error {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, file := range files {
|
||||
data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read %q: %v", file.Name(), err)
|
||||
}
|
||||
if err = m.Put(string(data)); err != nil {
|
||||
return fmt.Errorf("failed to process %q: %v", file.Name(), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateFunctionsForOps(w io.Writer, ops *pb.OpList, apimap *apiDefMap) error {
|
||||
thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
|
||||
if err := tmplHeader.Execute(w, thisPackage); err != nil {
|
||||
return err
|
||||
@ -83,14 +118,18 @@ func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error {
|
||||
if blacklist[op.Name] {
|
||||
continue
|
||||
}
|
||||
if err := generateFunctionForOp(w, op); err != nil {
|
||||
apidef, err := apimap.Get(op.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := generateFunctionForOp(w, op, apidef); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateFunctionForOp(w io.Writer, op *pb.OpDef) error {
|
||||
func generateFunctionForOp(w io.Writer, op *pb.OpDef, apidef *pb.ApiDef) error {
|
||||
if strings.HasPrefix(op.Name, "_") { // Internal operation
|
||||
return nil
|
||||
}
|
||||
@ -112,12 +151,16 @@ func generateFunctionForOp(w io.Writer, op *pb.OpDef) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if op.Summary == "" {
|
||||
if apidef.Summary == "" {
|
||||
// Undocumented operation, perhaps a sign of not being ready to
|
||||
// export.
|
||||
return nil
|
||||
}
|
||||
return tmplOp.Execute(w, newTmplArgs(op))
|
||||
tmplArgs, err := newTmplArgs(op, apidef)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return tmplOp.Execute(w, tmplArgs)
|
||||
}
|
||||
|
||||
var (
|
||||
@ -172,7 +215,7 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in
|
||||
type {{.Op.Name}}Attr func(optionalAttr)
|
||||
|
||||
{{range .OptionalAttrs}}
|
||||
// {{$.Op.Name}}{{CamelCase .Name}} sets the optional {{.Name}} attribute to value.
|
||||
// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
|
||||
{{- if .Description}}
|
||||
//
|
||||
// value: {{MakeComment .Description}}
|
||||
@ -180,9 +223,9 @@ type {{.Op.Name}}Attr func(optionalAttr)
|
||||
// If not specified, defaults to {{StripLeadingColon .DefaultValue}}
|
||||
{{- if .HasMinimum}}
|
||||
//
|
||||
// {{if IsListAttr .}}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
|
||||
// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
|
||||
{{- end}}
|
||||
func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
|
||||
func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
|
||||
return func(m optionalAttr) {
|
||||
m[{{printf "%q" .Name}}] = value
|
||||
}
|
||||
@ -192,14 +235,14 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
|
||||
|
||||
{{- /* Create a godoc friendly comment. */ -}}
|
||||
|
||||
// {{MakeComment .Op.Summary}}
|
||||
// {{MakeComment .APIDef.Summary}}
|
||||
|
||||
{{- with .Op.Deprecation}}
|
||||
//
|
||||
// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
|
||||
{{- end -}}
|
||||
|
||||
{{- with .Op.Description}}
|
||||
{{- with .APIDef.Description}}
|
||||
//
|
||||
// {{MakeComment .}}
|
||||
{{- end -}}
|
||||
@ -207,11 +250,11 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
|
||||
{{- if .DescribeArguments}}
|
||||
//
|
||||
// Arguments:
|
||||
{{- range .Op.InputArg}}
|
||||
// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}}
|
||||
{{- range .InArgsReordered}}
|
||||
// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
|
||||
{{- end -}}
|
||||
{{- range .RequiredAttrs}}
|
||||
// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}}
|
||||
// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
@ -221,12 +264,12 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
|
||||
{{- else }}
|
||||
{{- if .DescribeOutputs}}
|
||||
//
|
||||
{{- if ((len .Op.OutputArg) eq 1) }}
|
||||
// Returns {{range .Op.OutputArg}}{{MakeComment .Description}}{{end}}
|
||||
{{- if ((len .OutArgs) eq 1) }}
|
||||
// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
|
||||
{{- else }}
|
||||
// Returns:
|
||||
{{- range .Op.OutputArg}}
|
||||
// {{Identifier .Name}}{{if .Description}}: {{MakeComment .Description}}{{end}}
|
||||
{{- range .OutArgs}}
|
||||
// {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
@ -247,15 +290,15 @@ func {{.Op.Name}}
|
||||
*/ -}}
|
||||
|
||||
(scope *Scope
|
||||
{{- range $i, $a := .Op.InputArg}}, {{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}
|
||||
{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.Name}} {{GoType $a.Type}}{{end -}}
|
||||
{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
|
||||
{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
|
||||
{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
|
||||
)
|
||||
|
||||
{{- /* Construct outputs: len(OpDef.OutputArg) or a *tf.Operation */ -}}
|
||||
{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}
|
||||
|
||||
{{if .Op.OutputArg -}}
|
||||
({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}})
|
||||
{{if .OutArgs -}}
|
||||
({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
|
||||
{{- else -}}
|
||||
(o *tf.Operation)
|
||||
{{- end }} {
|
||||
@ -263,7 +306,7 @@ func {{.Op.Name}}
|
||||
return
|
||||
}
|
||||
{{if .HasAttrs -}}
|
||||
attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .Name}},{{end}}}
|
||||
attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
|
||||
{{if .OptionalAttrs -}}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
@ -272,16 +315,16 @@ func {{.Op.Name}}
|
||||
{{end -}}
|
||||
opspec := tf.OpSpec{
|
||||
Type: {{printf "%q" .Op.Name}},
|
||||
{{if .Op.InputArg -}}
|
||||
{{if .InArgs -}}
|
||||
Input: []tf.Input{
|
||||
{{range .Op.InputArg}}{{if IsListArg .}}tf.OutputList({{Identifier .Name}}){{else}}{{Identifier .Name}}{{end}}, {{end}}
|
||||
{{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
|
||||
},
|
||||
{{- end}}
|
||||
{{- if .HasAttrs}}
|
||||
Attrs: attrs,
|
||||
{{- end}}
|
||||
}
|
||||
{{- if .Op.OutputArg}}
|
||||
{{- if .OutArgs}}
|
||||
{{- if .HasListOutput}}
|
||||
op := scope.AddOperation(opspec)
|
||||
if scope.Err() != nil {
|
||||
@ -289,43 +332,105 @@ func {{.Op.Name}}
|
||||
}
|
||||
var idx int
|
||||
var err error
|
||||
{{- range $i, $a := .Op.OutputArg}}
|
||||
{{- if IsListArg $a}}
|
||||
if {{Identifier .Name}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
|
||||
{{- range $i, $a := .OutArgs}}
|
||||
{{- if $a.IsListArg}}
|
||||
if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
|
||||
scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
|
||||
return
|
||||
}
|
||||
{{- else }}
|
||||
{{Identifier .Name}} = op.Output(idx)
|
||||
{{Identifier .RenameTo}} = op.Output(idx)
|
||||
{{- end }}{{- /* if IsListArg */}}
|
||||
{{- end }}{{- /* range .Op.OutputArg */}}
|
||||
return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier .Name}}{{end}}
|
||||
{{- end }}{{- /* range .OutArgs */}}
|
||||
return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
|
||||
{{- else }}
|
||||
op := scope.AddOperation(opspec)
|
||||
return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
|
||||
return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
|
||||
{{- end }}{{- /* if .HasListOutput */}}
|
||||
{{- else }}
|
||||
return scope.AddOperation(opspec)
|
||||
{{- end }}{{- /* if .Op.OutputArg */}}
|
||||
{{- end }}{{- /* if .OutArgs */}}
|
||||
}
|
||||
`))
|
||||
)
|
||||
|
||||
type attrWrapper struct {
|
||||
op *pb.OpDef_AttrDef
|
||||
api *pb.ApiDef_Attr
|
||||
}
|
||||
|
||||
func (a *attrWrapper) Name() string { return a.api.Name }
|
||||
func (a *attrWrapper) RenameTo() string { return a.api.RenameTo }
|
||||
func (a *attrWrapper) Description() string { return a.api.Description }
|
||||
func (a *attrWrapper) Type() string { return a.op.Type }
|
||||
func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) }
|
||||
func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum }
|
||||
func (a *attrWrapper) Minimum() int64 { return a.op.Minimum }
|
||||
func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
|
||||
|
||||
type argWrapper struct {
|
||||
op *pb.OpDef_ArgDef
|
||||
api *pb.ApiDef_Arg
|
||||
}
|
||||
|
||||
func (a *argWrapper) Name() string { return a.api.Name }
|
||||
func (a *argWrapper) RenameTo() string { return a.api.RenameTo }
|
||||
func (a *argWrapper) Description() string { return a.api.Description }
|
||||
func (a *argWrapper) IsListArg() bool { return isListArg(a.op) }
|
||||
|
||||
type tmplArgs struct {
|
||||
Op *pb.OpDef
|
||||
Op *pb.OpDef
|
||||
APIDef *pb.ApiDef
|
||||
// Op.Attr is split into two categories
|
||||
// (1) Required: These must be specified by the client and are thus
|
||||
// included in the function signature.
|
||||
// (2) Optional: These need not be specified (as they have default
|
||||
// values) and thus do not appear in the function signature.
|
||||
RequiredAttrs []*pb.OpDef_AttrDef
|
||||
OptionalAttrs []*pb.OpDef_AttrDef
|
||||
RequiredAttrs []*attrWrapper
|
||||
OptionalAttrs []*attrWrapper
|
||||
InArgs []*argWrapper
|
||||
// Input arguments ordered based on arg_order field of ApiDef.
|
||||
InArgsReordered []*argWrapper
|
||||
OutArgs []*argWrapper
|
||||
}
|
||||
|
||||
func newTmplArgs(op *pb.OpDef) *tmplArgs {
|
||||
ret := tmplArgs{Op: op}
|
||||
func newTmplArgs(op *pb.OpDef, apidef *pb.ApiDef) (*tmplArgs, error) {
|
||||
ret := tmplArgs{Op: op, APIDef: apidef}
|
||||
|
||||
// Setup InArgs field
|
||||
for i, in := range op.InputArg {
|
||||
argCombined := argWrapper{op: in, api: apidef.InArg[i]}
|
||||
ret.InArgs = append(ret.InArgs, &argCombined)
|
||||
}
|
||||
|
||||
// Setup OutArgs field
|
||||
for i, out := range op.OutputArg {
|
||||
argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
|
||||
ret.OutArgs = append(ret.OutArgs, &argCombined)
|
||||
}
|
||||
|
||||
// Setup InArgsReordered field
|
||||
for _, argName := range apidef.ArgOrder {
|
||||
// Find the argument in op.InputArg
|
||||
argIndex := -1
|
||||
for i, in := range op.InputArg {
|
||||
if in.Name == argName {
|
||||
argIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if argIndex == -1 {
|
||||
return nil, fmt.Errorf(
|
||||
"couldn't find argument %s in ApiDef for op %s",
|
||||
argName, op.Name)
|
||||
}
|
||||
argCombined := argWrapper{
|
||||
op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
|
||||
ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
|
||||
}
|
||||
|
||||
if len(op.Attr) == 0 {
|
||||
return &ret
|
||||
return &ret, nil
|
||||
}
|
||||
// Attributes related to the InputArg's type are inferred automatically
|
||||
// and are not exposed to the client.
|
||||
@ -341,28 +446,29 @@ func newTmplArgs(op *pb.OpDef) *tmplArgs {
|
||||
inferred[in.NumberAttr] = true
|
||||
}
|
||||
}
|
||||
for _, attr := range op.Attr {
|
||||
for i, attr := range op.Attr {
|
||||
if inferred[attr.Name] {
|
||||
continue
|
||||
}
|
||||
attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
|
||||
if attr.DefaultValue == nil {
|
||||
ret.RequiredAttrs = append(ret.RequiredAttrs, attr)
|
||||
ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
|
||||
} else {
|
||||
ret.OptionalAttrs = append(ret.OptionalAttrs, attr)
|
||||
ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
|
||||
}
|
||||
}
|
||||
return &ret
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
|
||||
func (a *tmplArgs) DescribeArguments() bool {
|
||||
for _, arg := range a.Op.InputArg {
|
||||
if arg.Description != "" {
|
||||
for _, arg := range a.InArgs {
|
||||
if arg.Description() != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, attr := range a.RequiredAttrs {
|
||||
if attr.Description != "" {
|
||||
if attr.Description() != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -370,16 +476,16 @@ func (a *tmplArgs) DescribeArguments() bool {
|
||||
|
||||
}
|
||||
func (a *tmplArgs) DescribeOutputs() bool {
|
||||
for _, arg := range a.Op.OutputArg {
|
||||
if arg.Description != "" {
|
||||
for _, arg := range a.OutArgs {
|
||||
if arg.Description() != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (a *tmplArgs) HasListOutput() bool {
|
||||
for _, arg := range a.Op.OutputArg {
|
||||
if isListArg(arg) {
|
||||
for _, arg := range a.OutArgs {
|
||||
if arg.IsListArg() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -25,19 +25,44 @@ import (
|
||||
pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
|
||||
)
|
||||
|
||||
// Creates an ApiDef based on opdef and applies overrides
|
||||
// from apidefText (ApiDef text proto).
|
||||
func GetAPIDef(t *testing.T, opdef *pb.OpDef, apidefText string) *pb.ApiDef {
|
||||
opdefList := &pb.OpList{Op: []*pb.OpDef{opdef}}
|
||||
apimap, err := newAPIDefMap(opdefList)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = apimap.Put(apidefText)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
apidef, err := apimap.Get(opdef.Name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return apidef
|
||||
}
|
||||
|
||||
func TestGenerateOp(t *testing.T) {
|
||||
// TestGenerateOp validates the generated source code for an op.
|
||||
// The OpDef for the test cases are simplified forms of real ops.
|
||||
testdata := []struct {
|
||||
tag string
|
||||
opdef string
|
||||
apidef string
|
||||
wanted string
|
||||
}{
|
||||
{
|
||||
tag: "NoOp",
|
||||
opdef: `
|
||||
name: "NoOp"
|
||||
`,
|
||||
apidef: `
|
||||
op: <
|
||||
graph_op_name: "NoOp"
|
||||
summary: "No. Op."
|
||||
>
|
||||
`,
|
||||
wanted: `
|
||||
// No. Op.
|
||||
@ -80,8 +105,13 @@ attr: <
|
||||
>
|
||||
>
|
||||
>
|
||||
`,
|
||||
apidef: `
|
||||
op: <
|
||||
graph_op_name: "Add"
|
||||
summary: "Returns x + y element-wise."
|
||||
description: "Blah blah",
|
||||
>
|
||||
`,
|
||||
wanted: `
|
||||
// Returns x + y element-wise.
|
||||
@ -122,7 +152,12 @@ attr: <
|
||||
name: "DstT"
|
||||
type: "type"
|
||||
>
|
||||
`,
|
||||
apidef: `
|
||||
op: <
|
||||
graph_op_name: "Cast"
|
||||
summary: "Cast x of type SrcT to y of DstT."
|
||||
>
|
||||
`,
|
||||
wanted: `
|
||||
// Cast x of type SrcT to y of DstT.
|
||||
@ -149,12 +184,10 @@ func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) {
|
||||
name: "DecodeJpeg"
|
||||
input_arg: <
|
||||
name: "contents"
|
||||
description: "0-D. The JPEG-encoded image."
|
||||
type: DT_STRING
|
||||
>
|
||||
output_arg: <
|
||||
name: "image"
|
||||
description: "3-D with shape [height, width, channels]"
|
||||
type: DT_UINT8
|
||||
>
|
||||
attr: <
|
||||
@ -163,7 +196,6 @@ attr: <
|
||||
default_value: <
|
||||
i: 0
|
||||
>
|
||||
description: "Number of color channels for the decoded image."
|
||||
>
|
||||
attr: <
|
||||
name: "fancy_upscaling"
|
||||
@ -171,7 +203,6 @@ attr: <
|
||||
default_value: <
|
||||
b: true
|
||||
>
|
||||
description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
|
||||
>
|
||||
attr: <
|
||||
name: "acceptable_fraction"
|
||||
@ -179,10 +210,34 @@ attr: <
|
||||
default_value: <
|
||||
f: 1
|
||||
>
|
||||
>
|
||||
`,
|
||||
apidef: `
|
||||
op: <
|
||||
graph_op_name: "DecodeJpeg"
|
||||
in_arg: <
|
||||
name: "contents"
|
||||
description: "0-D. The JPEG-encoded image."
|
||||
>
|
||||
out_arg: <
|
||||
name: "image"
|
||||
description: "3-D with shape [height, width, channels]"
|
||||
>
|
||||
attr: <
|
||||
name: "channels"
|
||||
description: "Number of color channels for the decoded image."
|
||||
>
|
||||
attr: <
|
||||
name: "fancy_upscaling"
|
||||
description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
|
||||
>
|
||||
attr: <
|
||||
name: "acceptable_fraction"
|
||||
description: "The minimum required fraction of lines before a truncated\ninput is accepted."
|
||||
>
|
||||
summary: "Decode a JPEG-encoded image to a uint8 tensor."
|
||||
description: "Norna dorna fjord\nkajorna\nhahaha"
|
||||
>
|
||||
`,
|
||||
wanted: `
|
||||
// DecodeJpegAttr is an optional argument to DecodeJpeg.
|
||||
@ -270,7 +325,12 @@ attr: <
|
||||
name: "T"
|
||||
type: "type"
|
||||
>
|
||||
`,
|
||||
apidef: `
|
||||
op: <
|
||||
graph_op_name: "TwoOutputs"
|
||||
summary: "Op that produces multiple outputs"
|
||||
>
|
||||
`,
|
||||
wanted: `
|
||||
// Op that produces multiple outputs
|
||||
@ -326,8 +386,13 @@ attr: <
|
||||
>
|
||||
>
|
||||
>
|
||||
`,
|
||||
apidef: `
|
||||
op: <
|
||||
graph_op_name: "ShapeN"
|
||||
summary: "Returns shape of tensors."
|
||||
description: "Some description here."
|
||||
>
|
||||
`,
|
||||
wanted: `
|
||||
// ShapeNAttr is an optional argument to ShapeN.
|
||||
@ -371,6 +436,102 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
|
||||
}
|
||||
return output
|
||||
}
|
||||
`,
|
||||
},
|
||||
{
|
||||
tag: "ApiDefOverrides",
|
||||
opdef: `
|
||||
name: "TestOp"
|
||||
input_arg: <
|
||||
name: "a"
|
||||
type: DT_STRING
|
||||
>
|
||||
input_arg: <
|
||||
name: "b"
|
||||
type: DT_STRING
|
||||
>
|
||||
output_arg: <
|
||||
name: "c"
|
||||
type: DT_UINT8
|
||||
>
|
||||
attr: <
|
||||
name: "d"
|
||||
type: "int"
|
||||
default_value: <
|
||||
i: 0
|
||||
>
|
||||
>
|
||||
`,
|
||||
apidef: `
|
||||
op: <
|
||||
graph_op_name: "TestOp"
|
||||
in_arg: <
|
||||
name: "a"
|
||||
rename_to: "aa"
|
||||
description: "Description for aa."
|
||||
>
|
||||
in_arg: <
|
||||
name: "b"
|
||||
rename_to: "bb"
|
||||
description: "Description for bb."
|
||||
>
|
||||
arg_order: "b"
|
||||
arg_order: "a"
|
||||
out_arg: <
|
||||
name: "c"
|
||||
rename_to: "cc"
|
||||
description: "Description for cc."
|
||||
>
|
||||
attr: <
|
||||
name: "d"
|
||||
rename_to: "dd"
|
||||
description: "Description for dd."
|
||||
>
|
||||
summary: "Summary for TestOp."
|
||||
description: "Description for TestOp."
|
||||
>
|
||||
`,
|
||||
wanted: `
|
||||
// TestOpAttr is an optional argument to TestOp.
|
||||
type TestOpAttr func(optionalAttr)
|
||||
|
||||
// TestOpDd sets the optional dd attribute to value.
|
||||
//
|
||||
// value: Description for dd.
|
||||
// If not specified, defaults to 0
|
||||
func TestOpDd(value int64) TestOpAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["d"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Summary for TestOp.
|
||||
//
|
||||
// Description for TestOp.
|
||||
//
|
||||
// Arguments:
|
||||
// bb: Description for bb.
|
||||
// aa: Description for aa.
|
||||
//
|
||||
// Returns Description for cc.
|
||||
func TestOp(scope *Scope, bb tf.Output, aa tf.Output, optional ...TestOpAttr) (cc tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "TestOp",
|
||||
Input: []tf.Input{
|
||||
aa, bb,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
`,
|
||||
},
|
||||
}
|
||||
@ -378,11 +539,13 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
|
||||
for _, test := range testdata {
|
||||
t.Run(test.tag, func(t *testing.T) {
|
||||
var opdef pb.OpDef
|
||||
var apidef *pb.ApiDef
|
||||
var buf bytes.Buffer
|
||||
if err := proto.UnmarshalText(test.opdef, &opdef); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := generateFunctionForOp(&buf, &opdef); err != nil {
|
||||
apidef = GetAPIDef(t, &opdef, test.apidef)
|
||||
if err := generateFunctionForOp(&buf, &opdef, apidef); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, err := format.Source(buf.Bytes())
|
||||
|
@ -27,15 +27,17 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/tensorflow/tensorflow/tensorflow/go/genop/internal"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
filename = flag.String("outfile", "", "File to write generated source code to.")
|
||||
header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. Can be empty")
|
||||
buf bytes.Buffer
|
||||
filename = flag.String("outfile", "", "File to write generated source code to.")
|
||||
header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. Can be empty")
|
||||
apiDefDirs = flag.String("api_def_dirs", "", "Comma-separated directories containing api_def_*.pbtxt files.")
|
||||
buf bytes.Buffer
|
||||
)
|
||||
flag.Parse()
|
||||
if *filename == "" {
|
||||
@ -51,7 +53,13 @@ func main() {
|
||||
}
|
||||
os.MkdirAll(filepath.Dir(*filename), 0755)
|
||||
|
||||
if err := internal.GenerateFunctionsForRegisteredOps(&buf); err != nil {
|
||||
apiDefDirsList := []string{}
|
||||
if len(*apiDefDirs) > 0 {
|
||||
apiDefDirsList = strings.Split(*apiDefDirs, ",")
|
||||
}
|
||||
|
||||
if err := internal.GenerateFunctionsForRegisteredOps(
|
||||
&buf, apiDefDirsList); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
formatted, err := format.Source(buf.Bytes())
|
||||
|
@ -52,6 +52,12 @@ py_library(
|
||||
]),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "common",
|
||||
srcs = ["lib/common.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "debug_graphs",
|
||||
srcs = ["lib/debug_graphs.py"],
|
||||
@ -117,6 +123,7 @@ py_library(
|
||||
srcs = ["lib/source_remote.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":common",
|
||||
":debug_service_pb2_grpc",
|
||||
"//tensorflow/core/debug:debug_service_proto_py",
|
||||
"//tensorflow/python/profiler:tfprof_logger",
|
||||
@ -193,6 +200,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":command_parser",
|
||||
":common",
|
||||
":debugger_cli_common",
|
||||
":tensor_format",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
@ -334,7 +342,11 @@ py_library(
|
||||
name = "grpc_wrapper",
|
||||
srcs = ["wrappers/grpc_wrapper.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [":framework"],
|
||||
deps = [
|
||||
":common",
|
||||
":framework",
|
||||
":source_remote",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -345,6 +357,7 @@ py_library(
|
||||
":analyzer_cli",
|
||||
":cli_shared",
|
||||
":command_parser",
|
||||
":common",
|
||||
":debug_data",
|
||||
":debugger_cli_common",
|
||||
":framework",
|
||||
@ -439,6 +452,20 @@ py_binary(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "common_test",
|
||||
size = "small",
|
||||
srcs = ["lib/common_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":common",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "debug_graphs_test",
|
||||
size = "small",
|
||||
|
@ -55,7 +55,8 @@ def no_rewrite_session_config():
|
||||
rewriter_config = rewriter_config_pb2.RewriterConfig(
|
||||
disable_model_pruning=True,
|
||||
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
|
||||
|
||||
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
|
||||
return config_pb2.ConfigProto(graph_options=graph_options)
|
||||
|
@ -25,6 +25,7 @@ import six
|
||||
from tensorflow.python.debug.cli import command_parser
|
||||
from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.debug.cli import tensor_format
|
||||
from tensorflow.python.debug.lib import common
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
@ -214,51 +215,6 @@ def error(msg):
|
||||
RL("ERROR: " + msg, COLOR_RED)])
|
||||
|
||||
|
||||
def get_graph_element_name(elem):
|
||||
"""Obtain the name or string representation of a graph element.
|
||||
|
||||
If the graph element has the attribute "name", return name. Otherwise, return
|
||||
a __str__ representation of the graph element. Certain graph elements, such as
|
||||
`SparseTensor`s, do not have the attribute "name".
|
||||
|
||||
Args:
|
||||
elem: The graph element in question.
|
||||
|
||||
Returns:
|
||||
If the attribute 'name' is available, return the name. Otherwise, return
|
||||
str(fetch).
|
||||
"""
|
||||
|
||||
return elem.name if hasattr(elem, "name") else str(elem)
|
||||
|
||||
|
||||
def _get_fetch_names(fetches):
|
||||
"""Get a flattened list of the names in run() call fetches.
|
||||
|
||||
Args:
|
||||
fetches: Fetches of the `Session.run()` call. It maybe a Tensor, an
|
||||
Operation or a Variable. It may also be nested lists, tuples or
|
||||
dicts. See doc of `Session.run()` for more details.
|
||||
|
||||
Returns:
|
||||
(list of str) A flattened list of fetch names from `fetches`.
|
||||
"""
|
||||
|
||||
lines = []
|
||||
if isinstance(fetches, (list, tuple)):
|
||||
for fetch in fetches:
|
||||
lines.extend(_get_fetch_names(fetch))
|
||||
elif isinstance(fetches, dict):
|
||||
for key in fetches:
|
||||
lines.extend(_get_fetch_names(fetches[key]))
|
||||
else:
|
||||
# This ought to be a Tensor, an Operation or a Variable, for which the name
|
||||
# attribute should be available. (Bottom-out condition of the recursion.)
|
||||
lines.append(get_graph_element_name(fetches))
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _recommend_command(command, description, indent=2, create_link=False):
|
||||
"""Generate a RichTextLines object that describes a recommended command.
|
||||
|
||||
@ -327,14 +283,14 @@ def get_run_start_intro(run_call_count,
|
||||
(RichTextLines) Formatted intro message about the `Session.run()` call.
|
||||
"""
|
||||
|
||||
fetch_lines = _get_fetch_names(fetches)
|
||||
fetch_lines = common.get_flattened_names(fetches)
|
||||
|
||||
if not feed_dict:
|
||||
feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")]
|
||||
else:
|
||||
feed_dict_lines = []
|
||||
for feed_key in feed_dict:
|
||||
feed_key_name = get_graph_element_name(feed_key)
|
||||
feed_key_name = common.get_graph_element_name(feed_key)
|
||||
feed_dict_line = debugger_cli_common.RichLine(" ")
|
||||
feed_dict_line += debugger_cli_common.RichLine(
|
||||
feed_key_name,
|
||||
@ -446,10 +402,10 @@ def get_run_short_description(run_call_count,
|
||||
description = "run #%d: " % run_call_count
|
||||
|
||||
if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
|
||||
description += "1 fetch (%s); " % get_graph_element_name(fetches)
|
||||
description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
|
||||
else:
|
||||
# Could be (nested) list, tuple, dict or namedtuple.
|
||||
num_fetches = len(_get_fetch_names(fetches))
|
||||
num_fetches = len(common.get_flattened_names(fetches))
|
||||
if num_fetches > 1:
|
||||
description += "%d fetches; " % num_fetches
|
||||
else:
|
||||
|
87
tensorflow/python/debug/lib/common.py
Normal file
87
tensorflow/python/debug/lib/common.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Common values and methods for TensorFlow Debugger."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
|
||||
GRPC_URL_PREFIX = "grpc://"
|
||||
|
||||
# A key for a Session.run() call.
|
||||
RunKey = collections.namedtuple("RunKey", ["feed_names", "fetch_names"])
|
||||
|
||||
|
||||
def get_graph_element_name(elem):
|
||||
"""Obtain the name or string representation of a graph element.
|
||||
|
||||
If the graph element has the attribute "name", return name. Otherwise, return
|
||||
a __str__ representation of the graph element. Certain graph elements, such as
|
||||
`SparseTensor`s, do not have the attribute "name".
|
||||
|
||||
Args:
|
||||
elem: The graph element in question.
|
||||
|
||||
Returns:
|
||||
If the attribute 'name' is available, return the name. Otherwise, return
|
||||
str(fetch).
|
||||
"""
|
||||
|
||||
return elem.name if hasattr(elem, "name") else str(elem)
|
||||
|
||||
|
||||
def get_flattened_names(feeds_or_fetches):
|
||||
"""Get a flattened list of the names in run() call feeds or fetches.
|
||||
|
||||
Args:
|
||||
feeds_or_fetches: Feeds or fetches of the `Session.run()` call. It maybe
|
||||
a Tensor, an Operation or a Variable. It may also be nested lists, tuples
|
||||
or dicts. See doc of `Session.run()` for more details.
|
||||
|
||||
Returns:
|
||||
(list of str) A flattened list of fetch names from `feeds_or_fetches`.
|
||||
"""
|
||||
|
||||
lines = []
|
||||
if isinstance(feeds_or_fetches, (list, tuple)):
|
||||
for item in feeds_or_fetches:
|
||||
lines.extend(get_flattened_names(item))
|
||||
elif isinstance(feeds_or_fetches, dict):
|
||||
for key in feeds_or_fetches:
|
||||
lines.extend(get_flattened_names(feeds_or_fetches[key]))
|
||||
else:
|
||||
# This ought to be a Tensor, an Operation or a Variable, for which the name
|
||||
# attribute should be available. (Bottom-out condition of the recursion.)
|
||||
lines.append(get_graph_element_name(feeds_or_fetches))
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def get_run_key(feed_dict, fetches):
|
||||
"""Summarize the names of feeds and fetches as a RunKey JSON string.
|
||||
|
||||
Args:
|
||||
feed_dict: The feed_dict given to the `Session.run()` call.
|
||||
fetches: The fetches from the `Session.run()` call.
|
||||
|
||||
Returns:
|
||||
A JSON Array consisting of two items. They first items is a flattened
|
||||
Array of the names of the feeds. The second item is a flattened Array of
|
||||
the names of the fetches.
|
||||
"""
|
||||
return json.dumps(RunKey(get_flattened_names(feed_dict),
|
||||
get_flattened_names(fetches)))
|
59
tensorflow/python/debug/lib/common_test.py
Normal file
59
tensorflow/python/debug/lib/common_test.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Unit tests for common values and methods of TensorFlow Debugger."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.python.debug.lib import common
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class CommonTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testOnFeedOneFetch(self):
|
||||
a = constant_op.constant(10.0, name="a")
|
||||
b = constant_op.constant(20.0, name="b")
|
||||
run_key = common.get_run_key({"a": a}, [b])
|
||||
loaded = json.loads(run_key)
|
||||
self.assertItemsEqual(["a:0"], loaded[0])
|
||||
self.assertItemsEqual(["b:0"], loaded[1])
|
||||
|
||||
def testGetRunKeyFlat(self):
|
||||
a = constant_op.constant(10.0, name="a")
|
||||
b = constant_op.constant(20.0, name="b")
|
||||
run_key = common.get_run_key({"a": a}, [a, b])
|
||||
loaded = json.loads(run_key)
|
||||
self.assertItemsEqual(["a:0"], loaded[0])
|
||||
self.assertItemsEqual(["a:0", "b:0"], loaded[1])
|
||||
|
||||
def testGetRunKeyNestedFetches(self):
|
||||
a = constant_op.constant(10.0, name="a")
|
||||
b = constant_op.constant(20.0, name="b")
|
||||
c = constant_op.constant(30.0, name="c")
|
||||
d = constant_op.constant(30.0, name="d")
|
||||
run_key = common.get_run_key(
|
||||
{}, {"set1": [a, b], "set2": {"c": c, "d": d}})
|
||||
loaded = json.loads(run_key)
|
||||
self.assertItemsEqual([], loaded[0])
|
||||
self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], loaded[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -310,7 +310,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
|
||||
op_log_proto.id_to_string)
|
||||
raise ValueError(
|
||||
"Op '%s' does not exist in the tracebacks received by the debug "
|
||||
"server.")
|
||||
"server." % op_name)
|
||||
|
||||
def query_origin_stack(self):
|
||||
"""Query the stack of the origin of the execution call.
|
||||
@ -348,6 +348,9 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
|
||||
Raises:
|
||||
ValueError: If no source file is found at the given file_path.
|
||||
"""
|
||||
if not self._source_files:
|
||||
raise ValueError(
|
||||
"This debug server has not received any source file contents yet.")
|
||||
for source_file_proto in self._source_files.source_files:
|
||||
if source_file_proto.file_path == file_path:
|
||||
return source_file_proto.lines[lineno - 1]
|
||||
|
@ -248,7 +248,7 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
|
||||
self.assertEqual(
|
||||
14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
|
||||
|
||||
def testTensorBoardDebugHooWorks(self):
|
||||
def testTensorBoardDebugHookWorks(self):
|
||||
u = variables.Variable(2.1, name="u")
|
||||
v = variables.Variable(20.0, name="v")
|
||||
w = math_ops.multiply(u, v, name="w")
|
||||
@ -261,8 +261,37 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
|
||||
["localhost:%d" % self._server_port])
|
||||
sess = monitored_session._HookedSession(sess, [grpc_debug_hook])
|
||||
|
||||
# Activate watch point on some a tensor before calling sess.run().
|
||||
self._server.request_watch("u/read", 0, "DebugIdentity")
|
||||
self.assertAllClose(42.0, sess.run(w))
|
||||
|
||||
# self.assertAllClose(42.0, sess.run(w))
|
||||
dump = debug_data.DebugDumpDir(self._dump_root)
|
||||
self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
|
||||
|
||||
# Check that the server has received the stack trace.
|
||||
self.assertTrue(self._server.query_op_traceback("u"))
|
||||
self.assertTrue(self._server.query_op_traceback("u/read"))
|
||||
self.assertTrue(self._server.query_op_traceback("v"))
|
||||
self.assertTrue(self._server.query_op_traceback("v/read"))
|
||||
self.assertTrue(self._server.query_op_traceback("w"))
|
||||
|
||||
# Check that the server has received the python file content.
|
||||
# Query an arbitrary line to make sure that is the case.
|
||||
with open(__file__, "rt") as this_source_file:
|
||||
first_line = this_source_file.readline().strip()
|
||||
self.assertEqual(
|
||||
first_line, self._server.query_source_file_line(__file__, 1))
|
||||
|
||||
self._server.clear_data()
|
||||
# Call sess.run() again, and verify that this time the traceback and source
|
||||
# code is not sent, because the graph version is not newer.
|
||||
self.assertAllClose(42.0, sess.run(w))
|
||||
with self.assertRaises(ValueError):
|
||||
self._server.query_op_traceback("delta_1")
|
||||
with self.assertRaises(ValueError):
|
||||
self._server.query_source_file_line(__file__, 1)
|
||||
|
||||
def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
|
||||
hooks.GrpcDebugHook(["grpc://foo:42424"])
|
||||
hooks.GrpcDebugHook(["foo:42424"])
|
||||
@ -748,6 +777,28 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
|
||||
# to disable the breakpoint at delta:0:DebugIdentity.
|
||||
self.assertSetEqual(set(), self._server_1.breakpoints)
|
||||
|
||||
if i == 0:
|
||||
# Check that the server has received the stack trace.
|
||||
self.assertTrue(self._server_1.query_op_traceback("delta_1"))
|
||||
self.assertTrue(self._server_1.query_op_traceback("delta_2"))
|
||||
self.assertTrue(self._server_1.query_op_traceback("inc_v_1"))
|
||||
self.assertTrue(self._server_1.query_op_traceback("inc_v_2"))
|
||||
# Check that the server has received the python file content.
|
||||
# Query an arbitrary line to make sure that is the case.
|
||||
with open(__file__, "rt") as this_source_file:
|
||||
first_line = this_source_file.readline().strip()
|
||||
self.assertEqual(
|
||||
first_line, self._server_1.query_source_file_line(__file__, 1))
|
||||
else:
|
||||
# In later Session.run() calls, the traceback shouldn't have been sent
|
||||
# because it is already sent in the 1st call. So calling
|
||||
# query_op_traceback() should lead to an exception, because the test
|
||||
# debug server clears the data at the beginning of every iteration.
|
||||
with self.assertRaises(ValueError):
|
||||
self._server_1.query_op_traceback("delta_1")
|
||||
with self.assertRaises(ValueError):
|
||||
self._server_1.query_source_file_line(__file__, 1)
|
||||
|
||||
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
|
||||
with session.Session() as sess:
|
||||
v = variables.Variable(50.0, name="v")
|
||||
|
@ -24,6 +24,7 @@ import grpc
|
||||
|
||||
from tensorflow.core.debug import debug_service_pb2
|
||||
from tensorflow.core.protobuf import debug_pb2
|
||||
from tensorflow.python.debug.lib import common
|
||||
from tensorflow.python.debug.lib import debug_service_pb2_grpc
|
||||
from tensorflow.python.debug.lib import source_utils
|
||||
from tensorflow.python.platform import gfile
|
||||
@ -130,6 +131,11 @@ def _send_call_tracebacks(destinations,
|
||||
"""
|
||||
if not isinstance(destinations, list):
|
||||
destinations = [destinations]
|
||||
# Strip grpc:// prefix, if any is present.
|
||||
destinations = [
|
||||
dest[len(common.GRPC_URL_PREFIX):]
|
||||
if dest.startswith(common.GRPC_URL_PREFIX) else dest
|
||||
for dest in destinations]
|
||||
|
||||
call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
|
||||
if is_eager_execution
|
||||
|
@ -17,15 +17,55 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Google-internal import(s).
|
||||
from tensorflow.python.debug.lib import common
|
||||
from tensorflow.python.debug.wrappers import framework
|
||||
|
||||
|
||||
def publish_traceback(debug_server_urls,
|
||||
graph,
|
||||
feed_dict,
|
||||
fetches,
|
||||
old_graph_version):
|
||||
"""Publish traceback and source code if graph version is new.
|
||||
|
||||
`graph.version` is compared with `old_graph_version`. If the former is higher
|
||||
(i.e., newer), the graph traceback and the associated source code is sent to
|
||||
the debug server at the specified gRPC URLs.
|
||||
|
||||
Args:
|
||||
debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of
|
||||
debug server URLs.
|
||||
graph: A Python `tf.Graph` object.
|
||||
feed_dict: Feed dictionary given to the `Session.run()` call.
|
||||
fetches: Fetches from the `Session.run()` call.
|
||||
old_graph_version: Old graph version to compare to.
|
||||
|
||||
Returns:
|
||||
If `graph.version > old_graph_version`, the new graph version as an `int`.
|
||||
Else, the `old_graph_version` is returned.
|
||||
"""
|
||||
# TODO(cais): Consider moving this back to the top, after grpc becomes a
|
||||
# pip dependency of tensorflow or tf_debug.
|
||||
# pylint:disable=g-import-not-at-top
|
||||
from tensorflow.python.debug.lib import source_remote
|
||||
# pylint:enable=g-import-not-at-top
|
||||
if graph.version > old_graph_version:
|
||||
run_key = common.get_run_key(feed_dict, fetches)
|
||||
source_remote.send_graph_tracebacks(
|
||||
debug_server_urls, run_key, traceback.extract_stack(), graph,
|
||||
send_source=True)
|
||||
return graph.version
|
||||
else:
|
||||
return old_graph_version
|
||||
|
||||
|
||||
class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
|
||||
"""Debug Session wrapper that send debug data to gRPC stream(s)."""
|
||||
|
||||
_GRPC_URL_PREFIX = "grpc://"
|
||||
|
||||
def __init__(self,
|
||||
sess,
|
||||
grpc_debug_server_addresses,
|
||||
@ -94,8 +134,8 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
|
||||
return self._grpc_debug_server_urls
|
||||
|
||||
def _normalize_grpc_url(self, address):
|
||||
return (self._GRPC_URL_PREFIX + address
|
||||
if not address.startswith(self._GRPC_URL_PREFIX) else address)
|
||||
return (common.GRPC_URL_PREFIX + address
|
||||
if not address.startswith(common.GRPC_URL_PREFIX) else address)
|
||||
|
||||
|
||||
class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
|
||||
@ -126,3 +166,25 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
|
||||
watch_fn=_gated_grpc_watch_fn,
|
||||
thread_name_filter=thread_name_filter,
|
||||
log_usage=log_usage)
|
||||
|
||||
# Keeps track of the latest version of Python graph object that has been
|
||||
# sent to the debug servers.
|
||||
self._sent_graph_version = -sys.maxint
|
||||
|
||||
def run(self,
|
||||
fetches,
|
||||
feed_dict=None,
|
||||
options=None,
|
||||
run_metadata=None,
|
||||
callable_runner=None,
|
||||
callable_runner_args=None):
|
||||
self._sent_graph_version = publish_traceback(
|
||||
self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
|
||||
self._sent_graph_version)
|
||||
return super(TensorBoardDebugWrapperSession, self).run(
|
||||
fetches,
|
||||
feed_dict=feed_dict,
|
||||
options=options,
|
||||
run_metadata=run_metadata,
|
||||
callable_runner=callable_runner,
|
||||
callable_runner_args=callable_runner_args)
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.debug.lib import debug_utils
|
||||
from tensorflow.python.debug.lib import stepper
|
||||
@ -331,3 +333,13 @@ class TensorBoardDebugHook(GrpcDebugHook):
|
||||
watch_fn=_gated_grpc_watch_fn,
|
||||
thread_name_filter=thread_name_filter,
|
||||
log_usage=log_usage)
|
||||
|
||||
self._grpc_debug_server_addresses = grpc_debug_server_addresses
|
||||
self._sent_graph_version = -sys.maxint
|
||||
|
||||
def before_run(self, run_context):
|
||||
self._sent_graph_version = grpc_wrapper.publish_traceback(
|
||||
self._grpc_debug_server_addresses, run_context.session.graph,
|
||||
run_context.original_args.feed_dict, run_context.original_args.fetches,
|
||||
self._sent_graph_version)
|
||||
return super(TensorBoardDebugHook, self).before_run(run_context)
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.debug.cli import debugger_cli_common
|
||||
from tensorflow.python.debug.cli import profile_analyzer_cli
|
||||
from tensorflow.python.debug.cli import stepper_cli
|
||||
from tensorflow.python.debug.cli import ui_factory
|
||||
from tensorflow.python.debug.lib import common
|
||||
from tensorflow.python.debug.lib import debug_data
|
||||
from tensorflow.python.debug.wrappers import framework
|
||||
|
||||
@ -464,7 +465,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
feed_key = None
|
||||
feed_value = None
|
||||
for key in self._feed_dict:
|
||||
key_name = cli_shared.get_graph_element_name(key)
|
||||
key_name = common.get_graph_element_name(key)
|
||||
if key_name == tensor_name:
|
||||
feed_key = key_name
|
||||
feed_value = self._feed_dict[key]
|
||||
@ -561,7 +562,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
list(self._tensor_filters.keys()))
|
||||
if self._feed_dict:
|
||||
# Register tab completion for feed_dict keys.
|
||||
feed_keys = [cli_shared.get_graph_element_name(key)
|
||||
feed_keys = [common.get_graph_element_name(key)
|
||||
for key in self._feed_dict.keys()]
|
||||
curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)
|
||||
|
||||
|
@ -495,14 +495,13 @@ class GraphModeFunction(object):
|
||||
def _get_defun_inputs(args):
|
||||
"""Maps the inputs args to graph inputs."""
|
||||
ret = []
|
||||
for a in args:
|
||||
flat_args = nest.flatten(args)
|
||||
for a in flat_args:
|
||||
if isinstance(a, ops.Tensor):
|
||||
ret.append(graph_placeholder(a.dtype, a.shape))
|
||||
elif type(a) in (tuple, list):
|
||||
ret.append(_get_defun_inputs(a))
|
||||
else:
|
||||
ret.append(a)
|
||||
return tuple(ret) if type(args) is tuple else ret
|
||||
return nest.pack_sequence_as(args, ret)
|
||||
|
||||
|
||||
def _defun_internal(name, func, args, kwds):
|
||||
@ -582,8 +581,10 @@ def _cache_key(x):
|
||||
return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
|
||||
if isinstance(x, np.ndarray):
|
||||
return ("array", x.shape, tuple(x.reshape(-1)))
|
||||
if type(x) in (list, tuple):
|
||||
if isinstance(x, (list, tuple)):
|
||||
return tuple([_cache_key(a) for a in x])
|
||||
if isinstance(x, dict):
|
||||
return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items())
|
||||
return x
|
||||
|
||||
|
||||
|
@ -17,6 +17,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
@ -57,6 +59,20 @@ class FunctionTest(test.TestCase):
|
||||
out = sq(t)
|
||||
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
|
||||
|
||||
def testNestedInputsGraphMode(self):
|
||||
matmul = function.defun(math_ops.matmul)
|
||||
|
||||
pair = collections.namedtuple('pair', ['a', 'b'])
|
||||
|
||||
@function.defun
|
||||
def a_times_b(inputs):
|
||||
return matmul(inputs.a['a'], inputs.b['b'])
|
||||
|
||||
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
out = a_times_b(pair({'a': t}, {'b': t}))
|
||||
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
|
||||
|
||||
def testGraphModeWithGradients(self):
|
||||
v = resource_variable_ops.ResourceVariable(1.0, name='v')
|
||||
|
||||
@ -83,6 +99,22 @@ class FunctionTest(test.TestCase):
|
||||
out = sq_op(t)
|
||||
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
|
||||
|
||||
def testNestedInputsDefunOpGraphMode(self):
|
||||
matmul = function.defun(math_ops.matmul)
|
||||
|
||||
pair = collections.namedtuple('pair', ['a', 'b'])
|
||||
def a_times_b(inputs):
|
||||
return matmul(inputs.a['a'], inputs.b['b'])
|
||||
|
||||
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
inputs = pair({'a': t}, {'b': t})
|
||||
sq_op = function.make_defun_op(a_times_b, inputs)
|
||||
|
||||
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
|
||||
out = sq_op(inputs)
|
||||
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
|
||||
|
||||
def testNestedOutputDefunOpGraphMode(self):
|
||||
matmul = function.defun(math_ops.matmul)
|
||||
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "tensorflow/python/lib/core/bfloat16.h"
|
||||
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
@ -477,8 +479,61 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType, typename Functor>
|
||||
void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
|
||||
void* data) {
|
||||
const char* i0 = args[0];
|
||||
const char* i1 = args[1];
|
||||
char* o = args[2];
|
||||
for (npy_intp k = 0; k < *dimensions; k++) {
|
||||
InType x = *reinterpret_cast<const InType*>(i0);
|
||||
InType y = *reinterpret_cast<const InType*>(i1);
|
||||
*reinterpret_cast<OutType*>(o) = Functor()(x, y);
|
||||
i0 += steps[0];
|
||||
i1 += steps[1];
|
||||
o += steps[2];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Functor>
|
||||
void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
|
||||
void* data) {
|
||||
BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
|
||||
}
|
||||
|
||||
struct Bfloat16EqFunctor {
|
||||
npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
|
||||
};
|
||||
struct Bfloat16NeFunctor {
|
||||
npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
|
||||
};
|
||||
struct Bfloat16LtFunctor {
|
||||
npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
|
||||
};
|
||||
struct Bfloat16GtFunctor {
|
||||
npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
|
||||
};
|
||||
struct Bfloat16LeFunctor {
|
||||
npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
|
||||
};
|
||||
struct Bfloat16GeFunctor {
|
||||
npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
|
||||
};
|
||||
|
||||
// Initializes the module.
|
||||
bool Initialize() {
|
||||
// It's critical to import umath to avoid crash in open source build.
|
||||
import_umath1(false);
|
||||
|
||||
Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
|
||||
if (!numpy_str) {
|
||||
return false;
|
||||
}
|
||||
Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
|
||||
if (!numpy) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// We hit a mysterious crash if we haven't initialized numpy before this:
|
||||
PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
|
||||
|
||||
@ -536,6 +591,57 @@ bool Initialize() {
|
||||
/*cast_is_safe=*/true)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Register ufuncs
|
||||
auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn,
|
||||
const std::array<int, 3>& types) {
|
||||
Safe_PyObjectPtr ufunc_obj =
|
||||
make_safe(PyObject_GetAttrString(numpy.get(), name));
|
||||
if (!ufunc_obj) {
|
||||
return false;
|
||||
}
|
||||
PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
|
||||
if (types.size() != ufunc->nargs) {
|
||||
PyErr_Format(PyExc_AssertionError,
|
||||
"ufunc %s takes %d arguments, loop takes %lu", name,
|
||||
ufunc->nargs, types.size());
|
||||
return false;
|
||||
}
|
||||
if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn,
|
||||
const_cast<int*>(types.data()),
|
||||
nullptr) < 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Comparisons
|
||||
const std::array<int, 3> compare_types = {npy_bfloat16_, npy_bfloat16_,
|
||||
NPY_BOOL};
|
||||
|
||||
if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>,
|
||||
compare_types)) {
|
||||
return false;
|
||||
}
|
||||
if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>,
|
||||
compare_types)) {
|
||||
return false;
|
||||
}
|
||||
if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) {
|
||||
return false;
|
||||
}
|
||||
if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>,
|
||||
compare_types)) {
|
||||
return false;
|
||||
}
|
||||
if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>,
|
||||
compare_types)) {
|
||||
return false;
|
||||
}
|
||||
if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>,
|
||||
compare_types)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -172,6 +172,24 @@ class Bfloat16NumPyTest(test.TestCase):
|
||||
self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x))
|
||||
self.assertAllEqual(x, x)
|
||||
self.assertAllClose(x, x)
|
||||
self.assertTrue((x == x).all())
|
||||
|
||||
def testComparisons(self):
|
||||
x = np.array([401408, 7, -32], dtype=np.float32)
|
||||
bx = x.astype(bfloat16)
|
||||
y = np.array([82432, 7, 0], dtype=np.float32)
|
||||
by = y.astype(bfloat16)
|
||||
self.assertAllEqual(x == y, bx == by)
|
||||
self.assertAllEqual(x != y, bx != by)
|
||||
self.assertAllEqual(x < y, bx < by)
|
||||
self.assertAllEqual(x > y, bx > by)
|
||||
self.assertAllEqual(x <= y, bx <= by)
|
||||
self.assertAllEqual(x >= y, bx >= by)
|
||||
|
||||
def testEqual2(self):
|
||||
a = np.array([401408], bfloat16)
|
||||
b = np.array([82432], bfloat16)
|
||||
self.assertFalse(a.__eq__(b))
|
||||
|
||||
def testCasts(self):
|
||||
for dtype in [
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include <Python.h>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "numpy/ufuncobject.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -25,6 +25,7 @@ py_library(
|
||||
":main_op",
|
||||
":signature_constants",
|
||||
":signature_def_utils",
|
||||
":simple_save",
|
||||
":tag_constants",
|
||||
":utils",
|
||||
"//tensorflow/python:util",
|
||||
@ -89,6 +90,23 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "simple_save",
|
||||
srcs = [
|
||||
"simple_save.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":builder",
|
||||
":signature_constants",
|
||||
":signature_def_utils",
|
||||
":tag_constants",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "main_op",
|
||||
srcs = [
|
||||
@ -198,6 +216,22 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "simple_save_test",
|
||||
size = "small",
|
||||
srcs = ["simple_save_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":loader",
|
||||
":signature_constants",
|
||||
":simple_save",
|
||||
":tag_constants",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets. These must be at the end for syncrepo.
|
||||
|
||||
|
@ -30,6 +30,9 @@ from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.saved_model import utils
|
||||
# pylint: enable=unused-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.saved_model.simple_save import *
|
||||
# pylint: enable=wildcard-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
@ -41,6 +44,7 @@ _allowed_symbols = [
|
||||
"main_op",
|
||||
"signature_constants",
|
||||
"signature_def_utils",
|
||||
"simple_save",
|
||||
"tag_constants",
|
||||
"utils",
|
||||
]
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""SavedModel utility functions."""
|
||||
"""SavedModel simple save functionality."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -39,7 +39,7 @@ def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None):
|
||||
to configure a SavedModel, this method has a few practical implications:
|
||||
- It will be treated as a graph for inference / serving (i.e. uses the tag
|
||||
`tag_constants.SERVING`)
|
||||
- The saved model will load in TensorFlow Serving and supports the
|
||||
- The SavedModel will load in TensorFlow Serving and supports the
|
||||
[Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto).
|
||||
To use the Classify, Regress, or MultiInference APIs, please
|
||||
use either
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user