Merge commit for internal changes

This commit is contained in:
Erich Elsen 2017-12-20 11:01:08 -08:00
commit 15283d90e7
102 changed files with 4784 additions and 595 deletions

View File

@ -42,6 +42,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
],
}),
)
@ -73,6 +74,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -2618,4 +2619,54 @@ void TF_SessionPRun(TF_Session* session, const char* handle,
output_values, target_names, nullptr, status);
}
TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
tensorflow::OpList op_list;
if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
status->status = InvalidArgument("Unparseable OpList");
return nullptr;
}
status->status = Status::OK();
return new TF_ApiDefMap(op_list);
}
void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
size_t text_len, TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"ApiDefMap is not supported in Android.");
#else
mutex_lock l(api_def_map->lock);
if (api_def_map->update_docs_called) {
status->status = FailedPrecondition(
"TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
"called.");
return;
}
string api_def_text(text, text_len);
status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif // __ANDROID__
}
TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
size_t name_len, TF_Status* status) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"ApiDefMap is not supported in Android.");
return nullptr;
#else
mutex_lock l(api_def_map->lock);
if (!api_def_map->update_docs_called) {
api_def_map->api_def_map.UpdateDocs();
api_def_map->update_docs_called = true;
}
string name_str(name, name_len);
const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
TF_Buffer* ret = TF_NewBuffer();
status->status = MessageToBuffer(*api_def, ret);
return ret;
#endif // __ANDROID__
}
} // end extern "C"

View File

@ -1518,6 +1518,49 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle);
// in this address space.
TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList();
// TF_ApiDefMap encapsulates a collection of API definitions for an operation.
//
// This object maps the name of a TensorFlow operation to a description of the
// API to generate for it, as defined by the ApiDef protocol buffer (
// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto)
//
// The ApiDef messages are typically used to generate convenience wrapper
// functions for TensorFlow operations in various language bindings.
typedef struct TF_ApiDefMap TF_ApiDefMap;
// Creates a new TF_ApiDefMap instance.
//
// Params:
// op_list_buffer - TF_Buffer instance containing serialized OpList
// protocol buffer. (See
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto
// for the OpList proto definition).
// status - Set to OK on success and an appropriate error on failure.
TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer,
TF_Status* status);
// Deallocates a TF_ApiDefMap.
TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap);
// Add ApiDefs to the map.
//
// `text` corresponds to a text representation of an ApiDefs protocol message.
// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto).
//
// The provided ApiDefs will be merged with existing ones in the map, with
// precedence given to the newly added version in case of conflicts with
// previous calls to TF_ApiDefMapPut.
TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map,
const char* text, size_t text_len,
TF_Status* status);
// Returns a serialized ApiDef protocol buffer for the TensorFlow operation
// named `name`.
TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map,
const char* name,
size_t name_len,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -24,6 +24,9 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#ifndef __ANDROID__
#include "tensorflow/core/framework/op_gen_lib.h"
#endif
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -158,6 +161,22 @@ struct TF_Function {
tensorflow::FunctionDef fdef;
};
struct TF_ApiDefMap {
explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
:
#ifndef __ANDROID__
api_def_map(op_list),
#endif
update_docs_called(false) {
}
#ifndef __ANDROID__
tensorflow::ApiDefMap api_def_map GUARDED_BY(lock);
#endif
bool update_docs_called GUARDED_BY(lock);
tensorflow::mutex lock;
};
namespace tensorflow {
class TensorCApi {

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
@ -2027,6 +2028,77 @@ TEST_F(CApiAttributesTest, Errors) {
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
}
TEST(TestApiDef, TestCreateApiDef) {
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TF_Buffer op_list_buf = TF_GetOpList(lib);
status = TF_NewStatus();
auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
string op_name = "TestCApi";
status = TF_NewStatus();
auto* api_def_buf =
TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
tensorflow::ApiDef api_def;
EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
EXPECT_EQ(op_name, api_def.graph_op_name());
EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
TF_DeleteLibraryHandle(lib);
}
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
TF_Status* status = TF_NewStatus();
TF_Library* lib =
TF_LoadLibrary("tensorflow/c/test_op.so", status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TF_Buffer op_list_buf = TF_GetOpList(lib);
status = TF_NewStatus();
auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
string api_def_overwrites = R"(op: <
graph_op_name: "TestCApi"
summary: "New summary"
>
)";
status = TF_NewStatus();
TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(),
api_def_overwrites.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
string op_name = "TestCApi";
status = TF_NewStatus();
auto* api_def_buf =
TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
tensorflow::ApiDef api_def;
EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
EXPECT_EQ(op_name, api_def.graph_op_name());
EXPECT_EQ("New summary", api_def.summary());
TF_DeleteBuffer(api_def_buf);
TF_DeleteApiDefMap(api_def_map);
TF_DeleteLibraryHandle(lib);
}
#undef EXPECT_TF_META
} // namespace

View File

@ -33,7 +33,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
],
}),
}) + ["//tensorflow/core:gpu_runtime"],
)
tf_cuda_library(
@ -55,6 +55,10 @@ tf_cuda_library(
tf_cuda_cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
tags = [
"guitar",
"multi_gpu",
],
deps = [
":c_api",
"//tensorflow/core:lib",

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
@ -167,18 +168,6 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
if (is_same_device) {
return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
}
const bool src_cpu = IsCPU(srcd);
if (src_cpu == dst_cpu) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"TFE_TensorHandleCopyToDevice requires either the source "
"TFE_TensorHandle be on or the destination device be on CPU "
"or be the same (they are ",
DeviceName(srcd), " and ", DeviceName(dstd), " in this call)")
.c_str());
return nullptr;
}
tensorflow::Tensor* src = &(h->t);
if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) {
TF_SetStatus(
@ -189,26 +178,19 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
.c_str());
return nullptr;
}
if (src_cpu) {
tensorflow::Tensor dst(
dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(),
src->shape());
if (src->shape().num_elements() == 0) {
return new TFE_TensorHandle(dst, dstd);
}
tensorflow::Notification n;
dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice(
src, dstd, &dst, [status, &n](const tensorflow::Status& s) {
status->status = s;
n.Notify();
});
n.WaitForNotification();
return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd)
: nullptr;
tensorflow::Tensor dst(dstd->GetAllocator(tensorflow::AllocatorAttributes()),
src->dtype(), src->shape());
if (src->shape().num_elements() == 0) {
return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
}
tensorflow::DeviceContext* src_device_context = nullptr;
if (!IsCPU(srcd)) {
src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
}
tensorflow::DeviceContext* dst_device_context = nullptr;
if (!dst_cpu) {
dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
}
CHECK(dst_cpu);
tensorflow::Tensor dst(src->dtype(), src->shape());
tensorflow::Notification n;
// TODO(ashankar): The Sync() call below may be more aggressive than
// necessary. It is based on knowledge of implementation details - that
// GPU devices are implemented using 3 streams - one for host->device copies,
@ -217,16 +199,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
// but more than necessary (since it waits for operations that might have
// nothing to do with this tensor to complete).
status->status = srcd->Sync();
if (!status->status.ok()) return nullptr;
srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU(
src, "IGNORE_MY_TENSOR_NAME", srcd, &dst,
[status, &n](const tensorflow::Status& s) {
status->status = s;
n.Notify();
});
tensorflow::Notification n;
tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
srcd, dstd, tensorflow::AllocatorAttributes(),
tensorflow::AllocatorAttributes(), src, &dst,
[status, &n](const tensorflow::Status& s) {
status->status = s;
n.Notify();
});
n.WaitForNotification();
return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr)
: nullptr;
return (TF_GetCode(status) == TF_OK)
? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
: nullptr;
}
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,

View File

@ -216,6 +216,64 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
const char* kCPUDevice = "CPU:0";
if (num_devices < 3) {
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
return;
}
const string gpu_1_name(TF_DeviceListName(devices, 1, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
const string gpu_2_name(TF_DeviceListName(devices, 2, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice(
hdevice, ctx, gpu_2_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_DeleteTensorHandle(hdevice);
// Copy back to CPU
TFE_TensorHandle* hcopy =
TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
TFE_DeleteTensorHandle(hdevice2);
// Ensure that the contents are the same!
TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get());
TFE_DeleteTensorHandle(hcopy);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK);
EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy));
EXPECT_EQ(
0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t)));
TF_DeleteTensor(tcopy);
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, TensorHandleSilentCopy) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

View File

@ -248,6 +248,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",

File diff suppressed because it is too large Load Diff

View File

@ -48,6 +48,16 @@ typedef std::function<Status(
// 'group_attribute' must be a string valued-attribute that names the new
// functions to introduce.
//
// 'outside_compilation_attribute' must be a string-valued attribute that is
// used to tag nodes within a subgraph to be part of an 'outside_compilation'
// cluster within the subgraph. A cluster is formed from the set of nodes with
// the same value of outside_compilation_subgraph and group_attribute. The nodes
// in an outside_compilation cluster are left in the original graph. Edges
// crossing from the subgraph to an outside_compilation cluster nested in the
// subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
// crossing from an outside_compilation cluster into its enclosing subgraph are
// lifted into a SendFromHost/RecvFromHost pair of nodes.
//
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
//
@ -64,10 +74,10 @@ typedef std::function<Status(
// dep from B. Originally D must run after C, post-transformation this
// dependency is lost.
Status EncapsulateSubgraphsInFunctions(
string group_attribute, const Graph& graph_in,
const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
FunctionLibraryDefinition* library);
string group_attribute, string outside_compilation_attribute,
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
bool parallel_checking, bool reuse_existing_functions,
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.

View File

@ -36,7 +36,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
if (diff) {
*diff = strings::StrCat("Definition mismatch for function ",
a.signature().name(), ", expected:\n",
a.DebugString());
a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
@ -82,6 +82,24 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
<< diff << "\nActual: " << actual.DebugString(); \
} while (false)
// TODO(misard): remove these fake registrations once there are real Ops to be
// compiled.
REGISTER_OP("_XlaSendToHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaRecvFromHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaSendFromHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaRecvAtHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("InputTest").Output("o: float");
REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
@ -98,10 +116,32 @@ REGISTER_OP("AddNLikeTest")
.SetIsCommutative()
.SetIsAggregate();
Node* NoOp(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("NoOp", opts);
}
Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTest", opts);
}
Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
"_XlaRecvAtHost", opts.op_registry());
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
}
Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
"_XlaSendFromHost", opts.op_registry());
node_builder.Input(inputs);
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
}
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
return ops::UnaryOp("UnaryTest", std::move(a), opts);
}
@ -145,7 +185,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
if (!s.ok()) return s;
std::unique_ptr<Graph> graph_out;
s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph,
/*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/false,
/*reuse_existing_functions=*/false,
@ -178,6 +218,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) {
FunctionDefLibrary library_out = library_in;
TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
// If there are no marked nodes, funcification should be a no-op.
TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
}
@ -230,7 +271,6 @@ TEST(EncapsulateSubgraphsTest, OneFunction) {
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
// If there are no marked nodes, funcification should be a no-op.
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
@ -342,9 +382,9 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph,
&library));
"_cluster", "_outside", graph_before_encapsulation,
/*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false,
/*reuse_existing_functions=*/false, &graph, &library));
std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
EXPECT_EQ(expected_nodes, GraphNodes(*graph));
@ -374,9 +414,9 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
/*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph,
&library));
"_cluster", "_outside", graph_before_encapsulation,
/*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true,
/*reuse_existing_functions=*/false, &graph, &library));
std::vector<string> expected_nodes = {
"add1", "add2", "cluster1", "cluster1_parallel_check/_0",
@ -432,7 +472,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", graph_before,
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
@ -477,7 +517,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", graph_before,
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
@ -502,5 +542,678 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
EXPECT_EQ(1, guaranteed_consts);
}
// Test with one function to transform and one outside_compilation cluster.
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
// Give nodes 'c' and 'd' names that collide after lowercasing.
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d = Binary(b, c,
b1.opts().WithName("c").WithControlInput(c).WithAttr(
"_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"C:o:0", "c:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"c"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts().WithName("E").WithControlInputs({recv, b}));
Node* send = SendFromHost({e}, {DT_FLOAT},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* s = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one function to transform and two outside_compilation clusters.
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g = Binary(e, f,
b1.opts()
.WithName("G")
.WithControlInputs({e, f})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* h = Binary(d, e,
b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
Binary(g, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
{{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O2_send"},
"_XlaSendToHost",
{"D:o:0", "F:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"F"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"C:o:0", "D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"D"}},
{{"outside_compilation_O2_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O2_send"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"i_0_retval", "I:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv1 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
Node* send1 = SendFromHost({e}, {DT_FLOAT},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* recv2 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O2_recv"));
Node* g = Binary(e, ops::NodeOut(recv2, 1),
b2.opts().WithName("G").WithControlInputs({recv2, e}));
Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
Node* send2 = SendFromHost(
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
Node* s = NoOp(b2.opts()
.WithName("F1_sequencer")
.WithControlInputs({recv1, send1, recv2, send2}));
Binary(g, call, b2.opts().WithName("J").WithControlInput(s));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two functions to transform, each with one outside_compilation
// cluster.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g = Binary(e, f,
b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
"_encapsulate", "F2"));
Node* h = Binary(d, g,
b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
Node* i =
Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
Binary(g, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"f_0_retval:float", "d_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"C:o:0", "D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{"D"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"e_0_arg:float", "f_0_arg:float"},
{"g_0_retval:float", "i_0_retval:float"}, {},
{
{{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
{"f_0_arg", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"G:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
Node* send1 = SendFromHost({e}, {DT_FLOAT},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
Node* recv2 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
Node* h = Binary(ops::NodeOut(call1, 1), recv2,
b2.opts().WithName("H").WithControlInput(s1));
Node* send2 = SendFromHost(
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(e).Input(call1);
Node* call2 = b2.opts()
.WithControlInputs({s1, e, call1})
.FinalizeBuilder(&node_builder2);
Node* s2 = NoOp(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}));
Binary(call2, ops::NodeOut(call2, 1),
b2.opts().WithName("J").WithControlInput(s2));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no inputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f =
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Unary(f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* e = Unary(a, b2.opts().WithName("E"));
Node* send1 = SendFromHost(
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1));
Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no data inputs but has a
// control input from the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithControlInput(d)
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f =
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Unary(f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({})}},
{"D"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 =
RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
Node* send1 = SendFromHost(
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no outputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1));
Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no data outputs but has a
// control output to the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
Node* send1 = SendFromHost({}, {},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no outputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
FunctionDefLibrary library;
GraphDef graphdef;
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* e = Unary(a, b2.opts().WithName("E"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(e, call1, b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
} // namespace
} // namespace tensorflow

View File

@ -41,6 +41,7 @@ limitations under the License.
namespace tensorflow {
const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
namespace {

View File

@ -28,6 +28,10 @@ namespace tensorflow {
// encapsulate subgraphs pass.
extern const char* const kXlaClusterAttr;
// The attribute that marks nodes in a cluster to be placed outside the xla
// compilation by the encapsulate subgraphs pass.
extern const char* const kXlaOutsideCompilationAttr;
// Pass that marks a subset of operators in the graph with attribute
// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
class MarkForCompilationPass : public GraphOptimizationPass {

View File

@ -270,8 +270,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return false;
}
/* static */ tensorflow::gtl::ArraySlice<const int64>
LayoutUtil::PaddedDimensions(const Shape& shape) {
/* static */ tensorflow::gtl::ArraySlice<int64> LayoutUtil::PaddedDimensions(
const Shape& shape) {
CHECK(IsDense(shape));
return AsInt64Slice(shape.layout().padded_dimensions());
}

View File

@ -96,7 +96,7 @@ class LayoutUtil {
// Returns the padded_dimensions array for the given Shape. Requires that the
// shape is an array and has a dense layout.
static tensorflow::gtl::ArraySlice<const int64> PaddedDimensions(
static tensorflow::gtl::ArraySlice<int64> PaddedDimensions(
const Shape& shape);
// Returns the given index of the padded_dimensions array for the given Shape.

View File

@ -341,7 +341,7 @@ class Literal {
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFullWithMonotonicDim0MajorLayout(
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
// Creates a new literal from an array. The variants not ending with
@ -1233,10 +1233,9 @@ void Literal::PopulateWithValue(NativeT value,
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal>
Literal::CreateFullWithMonotonicDim0MajorLayout(
/* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
Shape this_shape = ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
auto literal = MakeUnique<Literal>();
*literal->mutable_shape() = this_shape;

View File

@ -46,6 +46,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",

View File

@ -166,6 +166,18 @@ ComputationDataHandle LocalComputationBuilder::Dot(
return builder_.Dot(lhs, rhs);
}
ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
dimension_numbers);
}
ComputationDataHandle LocalComputationBuilder::ConvertElementType(
const ComputationDataHandle& operand, PrimitiveType new_element_type) {
return builder_.ConvertElementType(operand, new_element_type);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
@ -113,6 +114,14 @@ class LocalComputationBuilder {
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);
ComputationDataHandle ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
PrimitiveType new_element_type);

View File

@ -22,18 +22,19 @@ limitations under the License.
//
// C++ Python
// -------------------------------------+---------------------------------------
// ComputationDataHandle <-> long
// ArraySlice<int64> <- sequence of long
// ArraySlice<ComputationDataHandle> <- sequence of long
// ComputationDataHandle <-> int
// ArraySlice<int64> <- sequence of int
// ArraySlice<ComputationDataHandle> <- sequence of int
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape <-> pair holding (dtype, dimensions)
// std::vector<Shape> <- sequence of shape information pairs
// PrimitiveType <- int
// ArraySlice<pair<int64, in64>> <- sequence of int pairs
// ConvolutionDimensionNumbers proto <- corresponding Python proto
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally. Also,
// "long" and "int" denote the Python types so named, not C.
// direction, or whether it is maintained bidirectionally.
//
// The Python objects corresponding to C++ Literals have the type:
//
@ -113,6 +114,27 @@ limitations under the License.
using namespace xla;
using namespace xla::swig;
namespace xla {
namespace swig {
bool GetIntAttr(PyObject* o, const char* field, int64* result) {
PyObject* fo = PyObject_GetAttrString(o, field);
if (!fo) {
return false;
}
const int64 value = numpy::PyIntOrPyLongToLong(fo);
if (value == -1 && PyErr_Occurred()) {
Py_DECREF(fo);
return false;
}
Py_DECREF(fo);
*result = value;
return true;
}
}
}
%}
// Required to use PyArray_* functions.
@ -278,6 +300,189 @@ tensorflow::ImportNumpy();
$1 = static_cast<PrimitiveType>(value);
}
// ArraySlice<pair<int64, in64>>
%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
(std::vector<std::pair<int64, int64> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
temps.reserve(size);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
if (!o) {
return NULL;
}
PyObject* first = PyTuple_GetItem(o, 0);
if (!first) {
Py_DECREF(o);
return NULL;
}
PyObject* first_pyint = numpy::PyNumberToPyInt(first);
if (!first_pyint) {
PyErr_SetString(
PyExc_TypeError,
"First pair item cannot be converted to int");
Py_DECREF(o);
return NULL;
}
PyObject* second = PyTuple_GetItem(o, 1);
if (!second) {
Py_DECREF(o);
Py_DECREF(first_pyint);
return NULL;
}
PyObject* second_pyint = numpy::PyNumberToPyInt(second);
if (!second_pyint) {
PyErr_SetString(
PyExc_TypeError,
"Second pair item cannot be converted to int");
Py_DECREF(o);
Py_DECREF(first_pyint);
return NULL;
}
const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint);
if (first_value == -1 && PyErr_Occurred()) {
Py_DECREF(o);
Py_DECREF(first_pyint);
Py_DECREF(second_pyint);
return NULL;
}
const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint);
if (second_value == -1 && PyErr_Occurred()) {
Py_DECREF(o);
Py_DECREF(first_pyint);
Py_DECREF(second_pyint);
return NULL;
}
temps.push_back(std::make_pair(first_value, second_value));
Py_DECREF(o);
}
$1 = temps;
}
// ConvolutionDimensionNumbers
%typemap(in) const ConvolutionDimensionNumbers&
(ConvolutionDimensionNumbers dimension_numbers) {
int64 value;
if (!GetIntAttr($input, "input_batch_dimension", &value)) {
return NULL;
}
dimension_numbers.set_input_batch_dimension(value);
if (!GetIntAttr($input, "input_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_input_feature_dimension(value);
if (!GetIntAttr($input, "output_batch_dimension", &value)) {
return NULL;
}
dimension_numbers.set_output_batch_dimension(value);
if (!GetIntAttr($input, "output_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_output_feature_dimension(value);
if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_kernel_output_feature_dimension(value);
if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) {
return NULL;
}
dimension_numbers.set_kernel_input_feature_dimension(value);
PyObject* o;
int length;
o = PyObject_GetAttrString($input, "input_spatial_dimensions");
if (!o) {
return NULL;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
return NULL;
}
dimension_numbers.add_input_spatial_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(o);
o = PyObject_GetAttrString($input, "kernel_spatial_dimensions");
if (!o) {
return NULL;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
return NULL;
}
dimension_numbers.add_kernel_spatial_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(o);
o = PyObject_GetAttrString($input, "output_spatial_dimensions");
if (!o) {
return NULL;
}
length = PySequence_Size(o);
if (length == -1) {
Py_DECREF(o);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(o, i);
if (!item) {
Py_DECREF(o);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(o);
return NULL;
}
dimension_numbers.add_output_spatial_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(o);
$1 = &dimension_numbers;
}
%ignoreall
%unignore xla;
%unignore xla::swig;
@ -314,6 +519,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
%unignore xla::swig::LocalComputationBuilder::Mul;

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum # pylint: disable=g-bad-import-order
import itertools
import numpy as np
@ -25,6 +26,12 @@ import numpy as np
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.python import pywrap_xla as c_api
class PaddingType(enum.Enum):
VALID = 1
SAME = 2
_UNARY_OPS = [
'Not',
'Abs',
@ -564,6 +571,79 @@ class ComputationBuilder(object):
return _wrap_data_handle(
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
def Conv(self, lhs, rhs, window_strides, padding):
"""Enqueues a Conv operation onto the computation.
Args:
lhs: ComputationDataHandle for the rank N+2 array of inputs.
rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
Returns: a ComputationDataHandle representing the Conv operation.
"""
if padding == PaddingType.SAME:
lhs_dims = self.GetShape(lhs).dimensions()
rhs_dims = self.GetShape(rhs).dimensions()
in_shape, filter_shape = lhs_dims[2:], rhs_dims[2:]
out_shape = np.ceil(np.true_divide(in_shape, window_strides)).astype(int)
pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0)
for out_size, stride, filter_size, in_size
in zip(out_shape, window_strides, filter_shape, in_shape)]
pads = [(pad_size // 2, pad_size - pad_size // 2)
for pad_size in pad_sizes]
else:
pads = [(0, 0)] * len(window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return _wrap_data_handle(
self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
_unwrap_data_handle(rhs),
window_strides,
pads,
(),
(),
dimension_numbers))
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
lhs: ComputationDataHandle for the rank N+2 array of inputs.
rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
window_strides: length-N array-like of kernel strides.
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
Returns:
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
"""
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return _wrap_data_handle(
self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
_unwrap_data_handle(rhs),
window_strides,
padding,
lhs_dilation,
rhs_dilation,
dimension_numbers))
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
nd = num_spatial_dims
dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()
dimension_numbers.input_batch_dimension = 0
dimension_numbers.input_feature_dimension = 1
dimension_numbers.output_batch_dimension = 0
dimension_numbers.output_feature_dimension = 1
dimension_numbers.kernel_output_feature_dimension = 0
dimension_numbers.kernel_input_feature_dimension = 1
dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
return dimension_numbers
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.

View File

@ -386,6 +386,46 @@ class SingleOpTest(LocalComputationTest):
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testConvF32Same(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(c.Constant(lhs), c.Constant(rhs),
[1, 1], xla_client.PaddingType.SAME)
result = np.array([[[[640., 700., 760., 300.],
[880., 940., 1000., 380.],
[1120., 1180., 1240., 460.]]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvF32Valid(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 2, 3, 4)
rhs = a(1, 2, 1, 2) * 10
c.Conv(c.Constant(lhs), c.Constant(rhs),
[2, 1], xla_client.PaddingType.VALID)
result = np.array([[[[640., 700., 760.],
[1120., 1180., 1240.]]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvWithGeneralPaddingF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=result)
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])

View File

@ -1182,6 +1182,11 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
}
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
pad, HloInstruction::CreateBroadcast(pad->shape(),
pad->mutable_operand(1), {}));
}
// Eliminate nop pads (padding all zero), and replace a pad with negative
// padding with a pad with non-negative padding followed by a slice.
bool all_zero = true;
@ -1624,6 +1629,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* reduce_window) {
if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) {
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateBroadcast(reduce_window->shape(),
reduce_window->mutable_operand(1), {}));
}
auto operand = reduce_window->mutable_operand(0);
const Window& window = reduce_window->window();
auto function = reduce_window->to_apply();
@ -1694,7 +1705,6 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
auto operand = transpose->mutable_operand(0);
if (std::is_sorted(transpose->dimensions().begin(),
transpose->dimensions().end())) {
VLOG(10) << "deleting no-op transpose";
@ -1721,6 +1731,18 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction* convolution) {
auto lhs = convolution->mutable_operand(0);
auto rhs = convolution->mutable_operand(1);
if (ShapeUtil::HasZeroElements(lhs->shape()) ||
ShapeUtil::HasZeroElements(rhs->shape())) {
return ReplaceWithNewInstruction(
convolution,
HloInstruction::CreateBroadcast(
convolution->shape(),
computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(convolution->shape().element_type(), {}),
computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))),
{}));
}
const auto& window = convolution->window();
if (!enable_conv_simplification_) {
return Status::OK();
@ -1813,18 +1835,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
// We already checked feature_dimension is most minor, so data in input_shape
// and row-major {conv_width,input_channels} are bitwise identical.
const Shape new_input_shape =
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
input_shape.element_type(), {conv_width, input_channels});
const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
input_shape.element_type(), {conv_width, input_channels});
// We already checked input_feature_dimension is more major than
// output_feature_dimension, so data in filter_shape and row-major
// {input_channels,output_channels} are bitwise identical.
const Shape new_filter_shape =
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
filter_shape.element_type(), {input_channels, output_channels});
const Shape dot_output_shape =
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
convolution_shape.element_type(), {conv_width, output_channels});
const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
filter_shape.element_type(), {input_channels, output_channels});
const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
convolution_shape.element_type(), {conv_width, output_channels});
// We cannot insert bitcasts if the layouts will not be compatible.
// TODO(b/33178038): Consider inserting a transpose if a bitcast would be

View File

@ -816,6 +816,120 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
1);
}
TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs"));
HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs"));
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.add_input_spatial_dimensions(1);
dnums.set_input_feature_dimension(2);
dnums.set_output_batch_dimension(0);
dnums.add_output_spatial_dimensions(1);
dnums.set_output_feature_dimension(2);
dnums.add_kernel_spatial_dimensions(0);
dnums.set_kernel_input_feature_dimension(1);
dnums.set_kernel_output_feature_dimension(2);
Window window;
WindowDimension* dim = window.add_dimensions();
dim->set_size(3);
dim->set_padding_low(0);
dim->set_padding_high(0);
dim->set_stride(1);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
dim->set_window_reversal(false);
// Create add computation.
std::unique_ptr<HloModule> module = CreateNewModule();
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
module->AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Convolution(lhs, rhs));
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
Window window;
for (int64 i = 0; i < 2; ++i) {
WindowDimension* dim = window.add_dimensions();
dim->set_size(1);
dim->set_padding_low(1);
dim->set_padding_high(1);
dim->set_window_dilation(1);
dim->set_base_dilation(1);
}
// Create add computation.
std::unique_ptr<HloModule> module = CreateNewModule();
HloComputation* add_computation = nullptr;
{
HloComputation::Builder builder(TestName() + ".add");
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
HloInstruction* p0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "p0"));
HloInstruction* p1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
add_computation = module->AddEmbeddedComputation(builder.Build());
}
builder.AddInstruction(HloInstruction::CreateReduceWindow(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
window, add_computation));
module->AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::ReduceWindow(param, op::Constant()));
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {3, 0}), "op"));
PaddingConfig padding;
for (int i = 0; i < 2; ++i) {
PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions();
dimension->set_edge_padding_low(1);
dimension->set_edge_padding_high(1);
dimension->set_interior_padding(0);
}
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
padding));
std::unique_ptr<HloModule> module = CreateNewModule();
module->AddEntryComputation(builder.Build());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Pad(param, op::Constant()));
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Broadcast(op::Constant()));
}
TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@ -1309,7 +1423,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) {
HloComputation::Builder builder(TestName());
HloInstruction* param0 =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(F32, {2, 2, 2}),
0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}),
"param0"));
HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary(

View File

@ -28,8 +28,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace se = ::perftools::gputools;
namespace xla {
StatusOr<GlobalDataHandle> AllocationTracker::Register(

View File

@ -96,6 +96,26 @@ class ConvolutionThunk : public Thunk {
return !best_algorithm_.has_value();
}
// Return true if scratch memory is needed to execute the thunk, that is
// either the best algorithm hasn't been chosen or the best algorithm is not
// the same as the no-scratch algorithm. This is because that the execution
// of the thunk is asynchronous, and the scratch allocator goes out of
// scope before the thunk finishes execution. Returning true tells the stream
// executor to make future thunks wait for this thunk to avoid reusing the
// deallocated scratch memory until this thunk is done with it.
bool ShouldBlockFutureThunks() {
if (!best_algorithm_.has_value()) {
return true;
}
const perftools::gputools::dnn::AlgorithmDesc& best_alg =
best_algorithm_->algorithm();
const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg =
best_algorithm_->algorithm_no_scratch();
return (!best_alg.is_default() || !no_scratch_best_alg.is_default() ||
!(best_alg == no_scratch_best_alg));
}
private:
tensorflow::Status ConvolveWithTune(
const perftools::gputools::dnn::BatchDescriptor& input_descriptor,

View File

@ -155,6 +155,7 @@ Status GpuExecutable::ExecuteThunks(
run_options->BorrowStream(main_stream->parent()->device_ordinal()));
}
std::map<int32, const Thunk*> last_blocking_thunk_for_stream;
std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
TF_RETURN_IF_ERROR(thunk->Initialize(*this));
@ -167,10 +168,17 @@ Status GpuExecutable::ExecuteThunks(
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
}
if (last_blocking_thunk_for_stream.count(stream_no)) {
stream->ThenWaitFor(FindOrDie(thunk_to_finish_event,
last_blocking_thunk_for_stream[stream_no])
.get());
}
// If this thunk requests it, wait for all currently-executing thunks to
// finish. This is useful e.g. if the thunk is about to perform autotuning.
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) {
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
last_blocking_thunk_for_stream.clear();
}
profiler.StartOperation();
@ -178,11 +186,14 @@ Status GpuExecutable::ExecuteThunks(
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
if (thunk_schedule_->Depended(thunk)) {
if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) {
auto finish_event = MakeUnique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
if (thunk->ShouldBlockFutureThunks()) {
last_blocking_thunk_for_stream[stream_no] = thunk;
}
}
profiler.FinishOperation(thunk->hlo_instruction());
}

View File

@ -419,8 +419,8 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
(segs.size() == i ? shape.dimensions().size() : segs[i]),
1, std::multiplies<int64>()));
}
return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(),
dimensions);
return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
dimensions);
}
// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
@ -442,11 +442,13 @@ std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
}
}
auto segs = ConsecutiveSegments(perm);
Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a);
Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b);
Shape norm_a =
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
Shape norm_b =
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
if (3 == segs.size() && 0 == perm[0]) {
Shape reduced_a = MergeDimensions(segs, norm_a);
Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
b.element_type(),
Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
return std::make_tuple(true, reduced_a, reduced_b);
@ -460,10 +462,11 @@ std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
return 3 == b.dimensions().size() &&
ShapeUtil::Compatible(
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a),
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
ShapeUtil::PermuteDimensions(
{0, 2, 1},
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b)));
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
b)));
}
// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from
@ -495,9 +498,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
Shape input_shape =
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape());
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
input.GetShape());
Shape output_shape =
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape());
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
output.GetShape());
input = input.CastToShape(input_shape, builder);
output = output.CastToShape(output_shape, builder);
@ -615,7 +620,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
builder))),
builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
ShapeUtil::MakeShapeWithDescendingLayout(
PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
builder);
const llvm_ir::IrArray::Index input_tile_origin = ({
@ -811,14 +816,15 @@ Status IrEmitterUnnested::EmitColumnReduction(
// input_shape to normalized_input_shape and a reshape from
// normalized_input_shape to input_matrix_shape.
const Shape normalized_input_shape =
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
input_shape);
auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
const std::vector<int64> transpose_dimension_mapping(
input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
const Shape input_matrix_shape =
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
input_shape.element_type(), {height, width});
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
{height, width});
const llvm_ir::IrArray::Index input_matrix_index(
{y, x}, input_matrix_shape, &ir_builder_);
const llvm_ir::IrArray::Index input_index =
@ -1054,13 +1060,14 @@ Status IrEmitterUnnested::EmitRowReduction(
// from input_shape to normalized_input_shape and a reshape from
// normalized_input_shape to input_3d_tensor_shape.
const Shape normalized_input_shape =
ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
input_shape);
auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
const std::vector<int64> transpose_dimension_mapping(
input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
const Shape input_3d_tensor_shape =
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
input_shape.element_type(), {depth, height, width});
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
{depth, height, width});
const llvm_ir::IrArray::Index input_3d_tensor_index(
{z, y, x}, input_3d_tensor_shape, &ir_builder_);
const llvm_ir::IrArray::Index input_index =

View File

@ -83,6 +83,16 @@ class Thunk {
return false;
}
// Indicates whether thunks scheduled after this one should wait for this one
// to complete before running. For example, a convolution thunk creates a
// scratch allocator, then kicks off a convolution in cudnn via the stream
// executor. When the stream executor call returns, the scratch allocator goes
// out of scope, and the scratch memory is deallocated. In this case, the
// convolution thunk needs to return true so that future thunks wait for the
// convolution thunk to avoid reusing the deallocated memory until the
// convolution thunk is done with it.
virtual bool ShouldBlockFutureThunks() { return false; }
// Execute the kernel for the thunk on the given stream. This method must be
// called after Initialize and can be called multiple times over Thunk's
// lifetime. Stream argument must be non-null.

View File

@ -1361,7 +1361,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
std::unique_ptr<Literal> arg_literal =
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
@ -1414,7 +1414,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
std::unique_ptr<Literal> result_literal =
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 8.0f);
Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
LiteralTestUtil::ExpectEqual(*result_literal, *result);
}

View File

@ -478,7 +478,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
} else if (instruction->opcode() == HloOpcode::kCustomCall) {
// Add constraints for kCustomCall instruction operands and instructions.
// For now we only support major-first layouts for all inputs and outputs.
Shape result_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
instruction->shape().element_type(),
AsInt64Slice(instruction->shape().dimensions()));
TF_RETURN_IF_ERROR(
@ -491,7 +491,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
}
Shape row_major_operand_shape =
ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
ShapeUtil::MakeShapeWithDescendingLayout(
operand_shape.element_type(),
AsInt64Slice(operand_shape.dimensions()));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(

View File

@ -189,20 +189,21 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
.ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
std::vector<int64> layout(dimensions.size());
std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
return MakeShapeWithLayout(element_type, dimensions, layout);
}
/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(
/* static */ Shape
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
const Shape& shape) {
std::vector<int64> dims(shape.dimensions_size());
for (int i = 0; i < shape.dimensions_size(); ++i) {
dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
}
return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims);
return MakeShapeWithDescendingLayout(shape.element_type(), dims);
}
/* static */ void ShapeUtil::PopulateShape(
@ -1148,9 +1149,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
// as input_shape/output_shape and the dimension-0-major layout. These two
// shapes are used for conversion between logical linear indices and
// multi-dimensional indices.
Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout(
Shape input_shape_dim0_major = MakeShapeWithDescendingLayout(
input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout(
Shape output_shape_dim0_major = MakeShapeWithDescendingLayout(
output_shape.element_type(), AsInt64Slice(output_shape.dimensions()));
for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) {

View File

@ -268,14 +268,18 @@ class ShapeUtil {
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major);
// Constructs a new shape with major-first layout.
static Shape MakeShapeWithMonotonicDim0MajorLayout(
// Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}).
static Shape MakeShapeWithDescendingLayout(
PrimitiveType element_type,
tensorflow::gtl::ArraySlice<int64> dimensions);
// Returns a new shape with major-first layout that has the same layout of
// elements with a different shape.
static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape);
// Returns a new Shape based on the given Shape with low-dimension-major
// layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions
// rearranged so that it has the same in-memory layout as the given shape.
//
// For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}.
static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
const Shape& shape);
// As MakeShape, but the object to write to is passed in.
static void PopulateShape(PrimitiveType element_type,

View File

@ -393,7 +393,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
auto shape = ShapeUtil::MakeShape(F32, input_dims);
std::unique_ptr<Literal> arg_literal =
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(input_dims, 1.0f);
Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
@ -402,7 +402,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
std::unique_ptr<Literal> expected =
Literal::CreateFullWithMonotonicDim0MajorLayout<float>(output_dims, 9.0f);
Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
}

View File

@ -397,6 +397,7 @@ py_test(
srcs = ["scan_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@ -491,6 +492,25 @@ py_test(
],
)
py_test(
name = "unique_dataset_op_test",
size = "small",
srcs = ["unique_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/contrib/stateless",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//third_party/py/numpy",
],
)
py_test(
name = "zip_dataset_op_test",
size = "small",

View File

@ -735,6 +735,20 @@ class BatchDatasetSerializationTest(
lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
num_outputs)
def _build_dataset_dense_to_sparse(self, components):
return dataset_ops.Dataset.from_tensor_slices(components).map(
lambda x: array_ops.fill([x], x)).apply(
batching.dense_to_sparse_batch(4, [12]))
def testDenseToSparseBatchDatasetCore(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
num_outputs = len(components) // 4
self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
lambda: self._build_dataset_dense_to_sparse(diff_comp),
num_outputs)
class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):

View File

@ -21,6 +21,7 @@ import itertools
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@ -124,5 +125,18 @@ class ScanDatasetTest(test.TestCase):
scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
class ScanDatasetSerialzationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_dataset(self, num_elements):
return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
def testScanCore(self):
num_output = 5
self.run_core_tests(lambda: self._build_dataset(num_output),
lambda: self._build_dataset(2), num_output)
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,96 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import unique
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class UniqueDatasetTest(test.TestCase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
Args:
dtype: The `dtype` of the elements in each test case.
test_cases: A list of pairs of lists. The first component is the test
input that will be passed to the transformation; the second component
is the expected sequence of outputs from the transformation.
"""
# The `current_test_case` will be updated when we loop over `test_cases`
# below; declare it here so that the generator can capture it once.
current_test_case = []
dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
dtype).apply(unique.unique())
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with self.test_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)
for element in expected:
if dtype == dtypes.string:
element = compat.as_bytes(element)
self.assertAllEqual(element, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testSimpleInt(self):
for dtype in [dtypes.int32, dtypes.int64]:
self._testSimpleHelper(dtype, [
([], []),
([1], [1]),
([1, 1, 1, 1, 1, 1, 1], [1]),
([1, 2, 3, 4], [1, 2, 3, 4]),
([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
])
def testSimpleString(self):
self._testSimpleHelper(dtypes.string, [
([], []),
(["hello"], ["hello"]),
(["hello", "hello", "hello"], ["hello"]),
(["hello", "world"], ["hello", "world"]),
(["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
])
class UniqueSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def testUnique(self):
def build_dataset(num_elements, unique_elem_range):
return dataset_ops.Dataset.range(num_elements).map(
lambda x: x % unique_elem_range).apply(unique.unique())
self.run_core_tests(lambda: build_dataset(200, 100),
lambda: build_dataset(40, 100), 100)
if __name__ == "__main__":
test.main()

View File

@ -105,6 +105,7 @@ py_library(
"resampling.py",
"scan_ops.py",
"stats_ops.py",
"unique.py",
],
srcs_version = "PY2AND3",
deps = [

View File

@ -0,0 +1,82 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unique element dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops
def unique():
"""Creates a `Dataset` from another `Dataset`, discarding duplicates.
Use this transformation to produce a dataset that contains one instance of
each unique element in the input. For example:
```python
dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
# Using `unique()` will drop the duplicate elements.
dataset = dataset.apply(tf.contrib.data.unique()) # ==> { 1, 37, 2 }
```
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return UniqueDataset(dataset)
return _apply_fn
class UniqueDataset(dataset_ops.Dataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
super(UniqueDataset, self).__init__()
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
raise TypeError(
"`tf.contrib.data.unique()` only supports inputs with a single "
"`tf.int32`, `tf.int64`, or `tf.string` component.")
def _as_variant_tensor(self):
return gen_dataset_ops.unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
output_types=nest.flatten(
sparse.as_dense_types(self.output_types, self.output_classes)))
@property
def output_classes(self):
return self._input_dataset.output_classes
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_types(self):
return self._input_dataset.output_types

View File

@ -110,6 +110,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/kfac/python/ops:utils",
"//tensorflow/contrib/tpu",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",

View File

@ -22,11 +22,14 @@ import numpy as np
import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@ -95,6 +98,18 @@ class SubGraphTest(test.TestCase):
filtered_list = sub_graph.filter_list(input_list)
self.assertEqual(filtered_list, [b])
def testVariableUses(self):
with ops.Graph().as_default():
var = variable_scope.get_variable('var', shape=[10, 10])
resource_var = variable_scope.get_variable(
'resource_var', shape=[10, 10], use_resource=True)
x = array_ops.zeros([3, 10])
z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
z1 = math_ops.matmul(x, resource_var)
sub_graph = utils.SubGraph((z0, z1))
self.assertEqual(2, sub_graph.variable_uses(var))
self.assertEqual(1, sub_graph.variable_uses(resource_var))
class UtilsTest(test.TestCase):
@ -253,6 +268,25 @@ class UtilsTest(test.TestCase):
np_inv = np.linalg.inv(x + damp * np.eye(size))
self.assertAllClose(sess.run(tf_inv), np_inv)
def testCrossReplicaMean(self):
"""Ensures that cross_replica_mean() executes only when num_shards > 1."""
with ops.Graph().as_default():
with tpu_function.tpu_shard_context(4):
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
self.assertNotEqual(mean, tensor)
with ops.Graph().as_default():
with tpu_function.tpu_shard_context(1):
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
self.assertEqual(mean, tensor)
with ops.Graph().as_default():
with self.assertRaises(ValueError): # Outside of TPU context.
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
if __name__ == '__main__':
test.main()

View File

@ -196,6 +196,7 @@ py_library(
srcs = ["utils.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/tpu",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",

View File

@ -267,6 +267,10 @@ class FisherFactor(object):
new_cov = math_ops.add_n(
tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
# Synchronize value across all TPU cores.
if utils.on_tpu():
new_cov = utils.cross_replica_mean(new_cov)
return moving_averages.assign_moving_average(
self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -27,6 +29,8 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
# Method used for inverting matrices.
POSDEF_INV_METHOD = "cholesky"
@ -226,11 +230,13 @@ class SubGraph(object):
"""
def __init__(self, outputs):
# Set of all ancestor Tensors, Ops to 'outputs'.
self._members = set()
self._recurse_add(outputs)
def _recurse_add(self, nodes):
"""Recursively adds all of nodes' ancestors."""
for node in nodes:
if node in self._members:
continue
@ -246,8 +252,25 @@ class SubGraph(object):
return node in self._members
def variable_uses(self, var):
"""Computes number of times a variable is used."""
return len(self._members.intersection(set(var.value().consumers())))
"""Computes number of times a variable is used.
Args:
var: Variable or ResourceVariable instance.
Returns:
Number of times a variable is used within this subgraph.
Raises:
ValueError: If 'var' is not a variable type.
"""
if isinstance(var, resource_variable_ops.ResourceVariable):
var = var.handle
elif isinstance(var, variables.Variable):
var = var.value()
else:
raise ValueError("%s does not appear to be a variable." % str(var))
return len(self._members.intersection(set(var.consumers())))
def filter_list(self, node_list):
"""Filters 'node_list' to nodes in this subgraph."""
@ -292,5 +315,34 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
return dysdx
def on_tpu():
"""Returns True when building a TPU computation."""
return tpu_function.get_tpu_context().number_of_shards is not None
def cross_replica_mean(tensor, name=None):
"""Takes mean value of a Tensor across all TPU cores.
Args:
tensor: Tensor to be synchronized.
name: None or string. Name of Op.
Returns:
Average of Tensor across all TPU cores.
Raises:
ValueError: If called outside of TPU context.
"""
with ops.name_scope(name, "cross_replica_mean", [tensor]):
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
raise ValueError(
"Cannot take cross_replica_mean() outside of TPU Context.")
if num_shards == 1:
return tensor
return tpu_ops.cross_replica_sum(tensor / num_shards)
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.

View File

@ -179,6 +179,10 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
ConcatenateTensorBuffers<ArrayDataType::kInt64>(
input_arrays, concatenation_axis, &concatenated_array);
break;
case ArrayDataType::kString:
ConcatenateTensorBuffers<ArrayDataType::kString>(
input_arrays, concatenation_axis, &concatenated_array);
break;
default:
LOG(FATAL) << "ArrayDataType not supported";
}

View File

@ -52,6 +52,7 @@ using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
@ -135,6 +136,8 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kInt32;
else if (dtype == DT_INT64)
return ArrayDataType::kInt64;
else if (dtype == DT_STRING)
return ArrayDataType::kString;
else
LOG(INFO) << "Unsupported data type in placehoder op: " << dtype;
return ArrayDataType::kNone;
@ -236,6 +239,27 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
}
}
void ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
ImportShape(input_shape.dim(), output_array->mutable_shape());
int input_flat_size = 1;
for (int k = 0; k < input_shape.dim_size(); k++) {
input_flat_size *= input_shape.dim(k).size();
}
auto& output_string_data =
output_array->GetMutableBuffer<ArrayDataType::kString>().data;
output_string_data.resize(input_flat_size);
if (input_flat_size != input_tensor.string_val_size()) {
LOG(FATAL) << "Input_content string_val doesn't have the right "
"dimensions for this string tensor.";
}
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
}
// Count the number of inputs of a given node. If
// `tf_import_flags.drop_control_dependency` is true, count the number of
// non-control-dependency inputs.
@ -261,23 +285,30 @@ void ConvertConstOperator(const NodeDef& node,
const auto dtype = GetDataTypeAttr(node, "dtype");
auto& array = model->GetOrCreateArray(node.name());
array.data_type = dtype == DT_FLOAT
? ArrayDataType::kFloat
: dtype == DT_INT32
? ArrayDataType::kInt32
: dtype == DT_INT64 ? ArrayDataType::kInt64
: ArrayDataType::kNone;
if (dtype == DT_FLOAT) {
ImportFloatArray(tensor, &array);
} else if (dtype == DT_INT32) {
ImportInt32Array(tensor, &array);
} else if (dtype == DT_INT64) {
ImportInt64Array(tensor, &array);
} else {
// do nothing, silently ignore the Const data. For example, there are consts
// of string type. We just make a dummy buffer to indicate that this array
// does not rely on external input.
array.GetMutableBuffer<ArrayDataType::kNone>();
switch (dtype) {
case DT_FLOAT:
array.data_type = ArrayDataType::kFloat;
ImportFloatArray(tensor, &array);
break;
case DT_INT32:
array.data_type = ArrayDataType::kInt32;
ImportInt32Array(tensor, &array);
break;
case DT_INT64:
array.data_type = ArrayDataType::kInt64;
ImportInt64Array(tensor, &array);
break;
case DT_STRING:
array.data_type = ArrayDataType::kString;
ImportStringArray(tensor, &array);
break;
default:
array.data_type = ArrayDataType::kNone;
// do nothing, silently ignore the Const data.
// We just make a dummy buffer to indicate that
// this array does not rely on external input.
array.GetMutableBuffer<ArrayDataType::kNone>();
break;
}
}
@ -1191,7 +1222,7 @@ void ConvertGatherOperator(const NodeDef& node,
CHECK_EQ(node.op(), "Gather");
CHECK_EQ(GetInputsCount(node, tf_import_flags), 2);
const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
CHECK(indices_data_type == DT_INT32);
CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
auto* op = new GatherOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));

View File

@ -153,7 +153,15 @@ enum class AxesOrder {
// because we'll be dropping the array anyway (e.g. some exotic array types
// may be involved only in debug-only subgraphs that we may not be interested
// in actually supporting).
enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 };
enum class ArrayDataType {
kNone,
kBool,
kFloat,
kUint8,
kInt32,
kInt64,
kString
};
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
template <ArrayDataType A>
@ -182,6 +190,10 @@ template <>
struct DataTypeImpl<ArrayDataType::kInt64> {
typedef int64 Type;
};
template <>
struct DataTypeImpl<ArrayDataType::kString> {
typedef string Type;
};
template <ArrayDataType A>
using DataType = typename DataTypeImpl<A>::Type;

View File

@ -53,6 +53,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
return ::tflite::TensorType_INT32;
case ArrayDataType::kUint8:
return ::tflite::TensorType_UINT8;
case ArrayDataType::kString:
return ::tflite::TensorType_STRING;
default:
// FLOAT32 is filled for unknown data types.
// TODO(ycling): Implement type inference in TF Lite interpreter.
@ -66,6 +68,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
return ArrayDataType::kFloat;
case ::tflite::TensorType_INT32:
return ArrayDataType::kInt32;
case ::tflite::TensorType_STRING:
return ArrayDataType::kString;
case ::tflite::TensorType_UINT8:
return ArrayDataType::kUint8;
default:
@ -82,6 +86,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
return CopyBuffer<ArrayDataType::kFloat>(array, builder);
case ArrayDataType::kInt32:
return CopyBuffer<ArrayDataType::kInt32>(array, builder);
case ArrayDataType::kString:
return CopyBuffer<ArrayDataType::kString>(array, builder);
case ArrayDataType::kUint8:
return CopyBuffer<ArrayDataType::kUint8>(array, builder);
default:
@ -99,6 +105,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
case ::tflite::TensorType_INT32:
return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
case ::tflite::TensorType_STRING:
return CopyBuffer<ArrayDataType::kString>(buffer, array);
case ::tflite::TensorType_UINT8:
return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
default:

View File

@ -316,6 +316,9 @@ void LogArray(int log_level, const Model& model, const string& name) {
case ArrayDataType::kUint8:
VLOG(log_level) << " Data type: kUint8";
break;
case ArrayDataType::kString:
VLOG(log_level) << " Data type: kString";
break;
default:
VLOG(log_level) << " Data type: other (numerical value: "
<< static_cast<int>(array.data_type) << ")";
@ -334,6 +337,9 @@ void LogArray(int log_level, const Model& model, const string& name) {
case ArrayDataType::kUint8:
VLOG(log_level) << " Final type: kUint8";
break;
case ArrayDataType::kString:
VLOG(log_level) << " Final type: kString";
break;
default:
VLOG(log_level) << " Final type: other (numerical value: "
<< static_cast<int>(array.data_type) << ")";
@ -1253,6 +1259,11 @@ int ElementSize(ArrayDataType data_type) {
return 1;
case ArrayDataType::kInt64:
return 8;
// Usually not critical limitation because strings are only input and/or
// output.
case ArrayDataType::kString:
LOG(FATAL) << "Transient arrays with strings are not supported yet";
return 0;
default:
LOG(FATAL) << "Should not get here.";
return 0;

View File

@ -82,22 +82,6 @@ py_test(
],
)
py_test(
name = "utils_test",
size = "small",
srcs = ["python/saved_model/utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":saved_model_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:variables",
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:signature_constants",
"//tensorflow/python/saved_model:tag_constants",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "UniqueDataset"
summary: "Creates a dataset that contains the unique elements of `input_dataset`."
}

View File

@ -35,25 +35,41 @@ bool IsAdd(const NodeDef& node) {
bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
bool IsAnyDiv(const NodeDef& node) {
return node.op() == "RealDiv" || node.op() == "Div" ||
node.op() == "FloorDiv" || node.op() == "TruncateDiv";
}
bool IsApproximateEqual(const NodeDef& node) {
return node.op() == "ApproximateEqual";
}
bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
bool IsBiasAdd(const NodeDef& node) {
return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
}
bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {
@ -92,39 +108,77 @@ bool IsEnter(const NodeDef& node) {
return op == "Enter" || op == "RefEnter";
}
bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
bool IsExit(const NodeDef& node) {
const auto& op = node.op();
return op == "Exit" || op == "RefExit";
}
bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
bool IsFusedBatchNormGradV1(const NodeDef& node) {
return node.op() == "FusedBatchNormGrad";
bool IsFusedBatchNormGrad(const NodeDef& node) {
const auto& op = node.op();
return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2";
}
bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
bool IsIdentity(const NodeDef& node) {
const auto& op = node.op();
return op == "Identity" || op == "RefIdentity";
}
bool IsIdentityN(const NodeDef& node) {
const auto& op = node.op();
return op == "IdentityN";
}
bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
bool IsMatMul(const NodeDef& node) {
const auto& op = node.op();
return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" ||
op == "SparseMatMul";
}
bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
bool IsMerge(const NodeDef& node) {
const auto& op = node.op();
return op == "Merge" || op == "RefMerge";
}
bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
bool IsNextIteration(const NodeDef& node) {
const auto& op = node.op();
return op == "NextIteration" || op == "RefNextIteration";
@ -138,6 +192,12 @@ bool IsPlaceholder(const NodeDef& node) {
op == "PlaceholderWithDefault";
}
bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
bool IsReciprocalGrad(const NodeDef& node) {
@ -209,12 +269,18 @@ bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
bool IsVariable(const NodeDef& node) {
const auto& op = node.op();
return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
op == "VarHandleOp" || op == "ReadVariableOp";
}
bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
namespace {
bool GetBoolAttr(const NodeDef& node, const string& name) {
return node.attr().count(name) > 0 && node.attr().at(name).b();
@ -284,5 +350,10 @@ bool IsValuePreserving(const NodeDef& node) {
return value_preserving_ops.count(node.op()) > 0;
}
bool HasOpDef(const NodeDef& node) {
const OpDef* op_def = nullptr;
return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
}
} // namespace grappler
} // end namespace tensorflow

View File

@ -24,11 +24,18 @@ namespace grappler {
bool IsAdd(const NodeDef& node);
bool IsAddN(const NodeDef& node);
bool IsAngle(const NodeDef& node);
bool IsAnyDiv(const NodeDef& node);
bool IsApproximateEqual(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsAssert(const NodeDef& node);
bool IsAtan2(const NodeDef& node);
bool IsBiasAdd(const NodeDef& node);
bool IsBiasAddGrad(const NodeDef& node);
bool IsBitcast(const NodeDef& node);
bool IsComplex(const NodeDef& node);
bool IsComplexAbs(const NodeDef& node);
bool IsConj(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
bool IsConv2D(const NodeDef& node);
@ -41,18 +48,38 @@ bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
bool IsExit(const NodeDef& node);
bool IsFloorDiv(const NodeDef& node);
bool IsFloorMod(const NodeDef& node);
bool IsFusedBatchNormGradV1(const NodeDef& node);
bool IsFusedBatchNormGrad(const NodeDef& node);
bool IsGreater(const NodeDef& node);
bool IsGreaterEqual(const NodeDef& node);
bool IsIdentity(const NodeDef& node);
bool IsIdentityN(const NodeDef& node);
bool IsIgamma(const NodeDef& node);
bool IsIgammac(const NodeDef& node);
bool IsImag(const NodeDef& node);
bool IsInvGrad(const NodeDef& node);
bool IsLess(const NodeDef& node);
bool IsLessEqual(const NodeDef& node);
bool IsLogicalAnd(const NodeDef& node);
bool IsLogicalNot(const NodeDef& node);
bool IsLogicalOr(const NodeDef& node);
bool IsMaximum(const NodeDef& node);
bool IsMerge(const NodeDef& node);
bool IsMinimum(const NodeDef& node);
bool IsMod(const NodeDef& node);
bool IsMul(const NodeDef& node);
bool IsMatMul(const NodeDef& node);
bool IsNextIteration(const NodeDef& node);
bool IsPad(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
bool IsPolygamma(const NodeDef& node);
bool IsPow(const NodeDef& node);
bool IsReal(const NodeDef& node);
bool IsRealDiv(const NodeDef& node);
bool IsRelu6Grad(const NodeDef& node);
bool IsReluGrad(const NodeDef& node);
@ -80,7 +107,10 @@ bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
bool IsTruncateMod(const NodeDef& node);
bool IsVariable(const NodeDef& node);
bool IsZeta(const NodeDef& node);
// Return true if the op is an aggregation (e.g. Add, AddN).
// Returns false if it could not be determined to be so.
@ -102,6 +132,9 @@ bool IsInvolution(const NodeDef& node);
// function returns true if the op commutes with all element-wise operations.
bool IsValuePreserving(const NodeDef& node);
// Returns true if we can find an opdef corresponding to the op of the node.
bool HasOpDef(const NodeDef& node);
} // end namespace grappler
} // end namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/node_def.pb.h"
@ -350,15 +351,16 @@ Status DependencyOptimizer::TransitiveReduction() {
num_nodes);
for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
const NodeDef& node = optimized_graph_->node(node_idx);
if (ModifiesFrameInfo(node)) {
// Ignore nodes that modify frame info.
if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
// Ignore function nodes and nodes that modify frame info.
continue;
}
for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
const string& input = node.input(input_slot);
const NodeDef* input_node = node_map_->GetNode(input);
if (ModifiesFrameInfo(*input_node)) {
// Ignore edges from nodes that modify frame info.
if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
// Ignore edges from nodes that modify frame info and from Merge nodes,
// because we cannot know which of it's input paths executes.
continue;
}
const int input_node_idx = node_to_idx_[input_node];
@ -375,6 +377,14 @@ Status DependencyOptimizer::TransitiveReduction() {
// of length > 1, we can drop that control dependency.
int num_controls_removed = 0;
std::vector<int> longest_distance(num_nodes);
// Map from target_index -> set of (input_slot, source_index), representing
// the control edges to remove. We sort them in reverse order by input slot,
// such that when we swap them out so we don't clobber the
// node(target).input() repeated field.
typedef std::pair<int, int> InputSlotAndSource;
std::unordered_map<
int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
control_edges_to_remove;
for (int source = 0; source < num_nodes; ++source) {
int highest_control_target = -1;
for (const auto& control_output : control_outputs[source]) {
@ -382,7 +392,7 @@ Status DependencyOptimizer::TransitiveReduction() {
highest_control_target = control_output.first;
}
}
if (highest_control_target < source) {
if (highest_control_target <= source) {
continue;
}
std::fill(longest_distance.begin() + source,
@ -391,7 +401,10 @@ Status DependencyOptimizer::TransitiveReduction() {
for (int input : inputs[target]) {
// If the input node is before source in the topo order, no path
// source -> input -> target can exits and we can skip it.
if (input >= source) {
// Also only extend a path from the source itself or from nodes that
// have a path from source, indicated by longest_distance[input] > 0.
if (input == source ||
(input > source && longest_distance[input] > 0)) {
// If source -> input -> target is longer than the longest
// path so far from source -> target, update the longest_distance.
int candidate_longest_distance = longest_distance[input] + 1;
@ -402,25 +415,36 @@ Status DependencyOptimizer::TransitiveReduction() {
}
}
// If the longest path from the source to the target of a control dependency
// is longer than 1, there exists an alternate path, and we can eliminate
// the control dependency since it is redundant.
// If the longest path from source to target of a control dependency is
// longer than 1, there exists an alternate path, and we can eliminate the
// redundant direct control dependency.
for (const auto& control_output : control_outputs[source]) {
const int target = control_output.first;
if (longest_distance[target] > 1) {
const int input_slot = control_output.second;
// We modify the node inplace here. This is safe because there can
// only be one control edge from a given source to a given target.
const NodeDef& source_node = optimized_graph_->node(source);
NodeDef* target_node = optimized_graph_->mutable_node(target);
target_node->mutable_input()->SwapElements(
input_slot, target_node->input_size() - 1);
node_map_->RemoveOutput(source_node.name(), target_node->name());
target_node->mutable_input()->RemoveLast();
++num_controls_removed;
control_edges_to_remove[target].emplace(input_slot, source);
VLOG(1) << "Removing edge from:\n"
<< optimized_graph_->node(source).DebugString() << "\n\nto:\n\n"
<< optimized_graph_->node(target).DebugString();
}
}
}
for (const auto& it : control_edges_to_remove) {
const int target = it.first;
NodeDef* target_node = optimized_graph_->mutable_node(target);
for (const InputSlotAndSource& slot_and_source : it.second) {
const int input_slot = slot_and_source.first;
const int source = slot_and_source.second;
const NodeDef& source_node = optimized_graph_->node(source);
CHECK_LT(input_slot, target_node->input_size());
target_node->mutable_input()->SwapElements(input_slot,
target_node->input_size() - 1);
node_map_->RemoveOutput(source_node.name(), target_node->name());
target_node->mutable_input()->RemoveLast();
++num_controls_removed;
}
}
VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
<< " control dependencies";
return Status::OK();
@ -442,36 +466,27 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
nodes_to_preserve_ = item.NodesToPreserve();
fetch_nodes_known_ = !item.fetch.empty();
VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
CleanControlInputs();
const int num_iterations = opt_level_ == RewriterConfig::AGGRESSIVE ? 2 : 1;
const int num_iterations = 2;
for (int iteration = 0; iteration < num_iterations; ++iteration) {
Status topo_sort_status;
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
// Prepare the graph for transitive reduction if enabled.
topo_sort_status = TopologicalSort(optimized_graph_);
}
// Perform topological sort to prepare the graph for transitive reduction.
topo_sort_status = TopologicalSort(optimized_graph_);
// Set up index-based graph datastructures to speed up analysis steps below.
node_map_.reset(new NodeMap(optimized_graph_));
BuildNodeToIdx();
// Remove redundant control dependencies, iteration 1.
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
if (topo_sort_status.ok()) {
TF_RETURN_IF_ERROR(TransitiveReduction());
} else {
LOG(ERROR) << topo_sort_status.error_message();
}
VLOG(1) << "Graph after transitive reduction:\n"
<< optimized_graph_->DebugString();
if (topo_sort_status.ok()) {
// Remove redundant control dependencies.
TF_RETURN_IF_ERROR(TransitiveReduction());
} else {
LOG(ERROR) << topo_sort_status.error_message();
}
// Turn nodes without non-control outputs into NoOps, prune NoOps.
// Turn nodes with only control outputs into NoOps, prune NoOps.
TF_RETURN_IF_ERROR(OptimizeDependencies());
VLOG(1) << "Graph after NoOp conversion & pruning:\n"
<< optimized_graph_->DebugString();
}
VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString();
return Status::OK();
}

View File

@ -157,7 +157,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
DependencyOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@ -228,6 +228,7 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) {
// The optimization should be disabled to prevent increasing the number of
// nodes crossing device boundaries.
TF_CHECK_OK(TopologicalSort(&item.graph));
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
@ -282,7 +283,7 @@ TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch.push_back("id2");
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
DependencyOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);

View File

@ -60,7 +60,9 @@ std::set<string> GetOpsFormatSupported() {
"DepthwiseConv2dNativeBackpropInput",
"DepthwiseConv2dNativeBackpropFilter",
"FusedBatchNorm",
"FusedBatchNormV2",
"FusedBatchNormGrad",
"FusedBatchNormGradV2",
"FusedConv2DBiasActivation",
"MaxPool",
"MaxPoolGrad",
@ -75,52 +77,77 @@ std::set<string> GetOpsFormatAgnostic() {
std::set<string> ops_format_agnostic = {"Abs",
"Add",
"AddN",
"AddV2",
"Acos",
"Acosh",
"Angle",
"ApproximateEqual",
"Asin",
"Asinh",
"Atan",
"Atan2",
"Atanh",
"Bitcast",
"Cast",
"Ceil",
"CheckNumerics",
"Cos",
"Cosh",
"Complex",
"ComplexAbs",
"Concat",
"ConcatV2",
"Conj",
"Cos",
"Cosh",
"Digamma",
"Div",
"Elu",
"EluGrad",
"Equal",
"Erf",
"Erfc",
"Exp",
"Expm1",
"Floor",
"FloorDiv",
"FloorMod",
"Greater",
"GreaterEqual",
"GuaranteeConst",
"Identity",
"IdentityN",
"Igamma",
"Igammac",
"Imag",
"Inv",
"InvGrad",
"IsFinite",
"IsInf",
"IsNan",
"Less",
"LessEqual",
"Lgamma",
"Log",
"LogicalAnd",
"LogicalNot",
"LogicalOr",
"Log1p",
"Maximum",
"Merge",
"Minimum",
"Mod",
"Mul",
"Neg",
"NotEqual",
"OnesLike",
"Pad",
"PreventGradient",
"Polygamma",
"Pow",
"Real",
"RealDiv",
"Reciprocal",
"ReciprocalGrad",
"RefIdentity",
"Relu",
"Relu6",
"Relu6Grad",
@ -141,7 +168,8 @@ std::set<string> GetOpsFormatAgnostic() {
"SoftplusGrad",
"Split",
"Switch",
"RefIdentity",
"TruncateDiv",
"TruncateMod",
"RefMerge",
"RefSwitch",
"Round",
@ -157,7 +185,8 @@ std::set<string> GetOpsFormatAgnostic() {
"Tan",
"Tanh",
"TanhGrad",
"ZerosLike"};
"ZerosLike",
"Zeta"};
return ops_format_agnostic;
}
@ -212,6 +241,28 @@ bool IsUnaryGrad(const NodeDef& node) {
return is_unary_grad;
}
bool IsComparisonOp(const NodeDef& node) {
bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
IsLessEqual(node) || IsNotEqual(node);
return is_compare;
}
bool IsLogicalOp(const NodeDef& node) {
return IsLogicalAnd(node) || IsLogicalNot(node) || IsLogicalOr(node);
}
bool IsBinaryOp(const NodeDef& node) {
bool is_binary =
IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
return is_binary;
}
class GraphProcessor {
public:
GraphProcessor(const VirtualPlacer& virtual_placer,
@ -409,17 +460,19 @@ class NodeProcessor : public GraphProcessor {
virtual void UpdateAttrShape() {
if (node_->attr().find("_output_shapes") != node_->attr().end()) {
auto shape = node_->mutable_attr()
->at("_output_shapes")
.mutable_list()
->mutable_shape(0);
if (shape->dim_size() == 4) {
int64 h = shape->dim(1).size();
int64 w = shape->dim(2).size();
int64 c = shape->dim(3).size();
shape->mutable_dim(1)->set_size(c);
shape->mutable_dim(2)->set_size(h);
shape->mutable_dim(3)->set_size(w);
for (const auto& pos : GetOutputPos()) {
auto shape = node_->mutable_attr()
->at("_output_shapes")
.mutable_list()
->mutable_shape(pos);
if (shape->dim_size() == 4) {
int64 h = shape->dim(1).size();
int64 w = shape->dim(2).size();
int64 c = shape->dim(3).size();
shape->mutable_dim(1)->set_size(c);
shape->mutable_dim(2)->set_size(h);
shape->mutable_dim(3)->set_size(w);
}
}
}
}
@ -603,13 +656,15 @@ class NodeProcessor : public GraphProcessor {
added_node_name = AddPrefixToNodeName(added_node_base_name,
kTransposeNCHWToNHWC, "-");
DataType dtype;
if (op == "Imag" || op == "Real" || op == "Angle" ||
op == "Conj" || op == "ComplexAbs") {
if (IsAngle(*node_) || IsComplex(*node_) ||
IsComplexAbs(*node_) || IsImag(*node_) || IsReal(*node_)) {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "Tout"));
dtype = node_->attr().at("Tout").type();
} else if (op == "Bitcast") {
} else if (IsBitcast(*node_)) {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "type"));
dtype = node_->attr().at("type").type();
} else if (IsLogicalOp(*node_) || IsComparisonOp(*node_)) {
dtype = DT_BOOL;
} else {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
dtype = node_->attr().at("T").type();
@ -617,7 +672,8 @@ class NodeProcessor : public GraphProcessor {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
AddNodeTranspose(
added_node_name, input, const_name, dtype,
node_->attr().at("_output_shapes").list().shape(0), false);
node_->attr().at("_output_shapes").list().shape(input_port),
false);
} else if (op == "DataFormatVecPermute") {
added_node_name = AddPrefixToNodeName(added_node_base_name,
kVecPermuteNCHWToNHWC, "-");
@ -1002,11 +1058,10 @@ class AgnosticNodeProcessor : public NodeProcessor {
if (IsConcatV1(node)) {
return {1};
}
if (IsAdd(node) || IsMul(node) || IsRealDiv(node) ||
IsSquaredDifference(node) || IsSub(node)) {
if (IsBinaryOp(node) || IsUnaryGrad(node)) {
return {0, 1};
}
if (IsShapeN(node)) {
if (IsShapeN(node) || IsIdentityN(node)) {
std::vector<int> pos;
for (int i = 0; i < node.input_size(); i++) {
pos.push_back(i);
@ -1207,6 +1262,40 @@ class ConcatProcessor : public AgnosticNodeProcessor {
int axis_node_pos_;
};
class IdentityNProcessor : public AgnosticNodeProcessor {
public:
explicit IdentityNProcessor(const OptimizeContext& opt_cxt)
: AgnosticNodeProcessor(opt_cxt) {}
protected:
bool ShouldProcess() const override {
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsOnGPU();
}
std::vector<int> GetInputPos() const override {
std::vector<int> input_pos;
for (int i = 0; i < node_->input_size(); i++) {
auto input = node_map_->GetNode(node_->input(i));
int port;
ParseNodeName(node_->input(i), &port);
if (IsPortDimsFour(*input, port) &&
(IsNodeAfterNCHWToNHWC(*input) || IsNodeNCHWToNHWC(input->name()))) {
input_pos.push_back(i);
}
}
return input_pos;
}
std::set<int> GetOutputPos() const override {
std::set<int> output_pos{};
for (const auto& input_pos : GetInputPos()) {
output_pos.insert(input_pos);
}
return output_pos;
}
};
class MergeProcessor : public AgnosticNodeProcessor {
public:
explicit MergeProcessor(const OptimizeContext& opt_cxt)
@ -1371,12 +1460,15 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
bool IsInputConvertible() const {
int input_port;
auto input = node_map_->GetNode(node_->input(0));
ParseNodeName(node_->input(0), &input_port);
if (IsNodeNCHWToNHWC(input->name())) {
input = node_map_->GetNode(input->input(0));
ParseNodeName(input->input(0), &input_port);
}
if (input->attr().find("_output_shapes") != input->attr().end()) {
auto shape = input->attr().at("_output_shapes").list().shape(0);
auto shape = input->attr().at("_output_shapes").list().shape(input_port);
if (shape.dim_size() != 4) {
return false;
}
@ -1529,7 +1621,7 @@ class DataLayoutOptimizer : GraphProcessor {
new Conv2DBackpropFilterProcessor(opt_cxt, true));
} else if (IsDepthwiseConv2dNativeBackpropInput(*node)) {
node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true));
} else if (IsFusedBatchNormGradV1(*node)) {
} else if (IsFusedBatchNormGrad(*node)) {
node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt));
} else if (IsMaxPoolGradV1(*node)) {
node_processor.reset(new MaxPoolGradProcessor(opt_cxt));
@ -1557,11 +1649,12 @@ class DataLayoutOptimizer : GraphProcessor {
std::unique_ptr<NodeProcessor> node_processor;
if (IsAddN(*node)) {
node_processor.reset(new AddNProcessor(opt_cxt));
} else if (IsAdd(*node) || IsMul(*node) || IsRealDiv(*node) ||
IsSquaredDifference(*node) || IsSub(*node)) {
} else if (IsBinaryOp(*node)) {
node_processor.reset(new BinaryOpProcessor(opt_cxt));
} else if (IsConcat(*node)) {
node_processor.reset(new ConcatProcessor(opt_cxt));
} else if (IsIdentityN(*node)) {
node_processor.reset(new IdentityNProcessor(opt_cxt));
} else if (IsMerge(*node)) {
node_processor.reset(new MergeProcessor(opt_cxt));
} else if (IsPad(*node)) {

View File

@ -1066,6 +1066,48 @@ TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
"LayoutOptimizerTransposeNCHWToNHWC-Conv2D-0-1");
}
TEST_F(LayoutOptimizerTest, Complex) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 4, 2, "VALID");
auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
auto i = ops::Identity(s.WithOpName("i"), comp);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto merge_node = node_map.GetNode("complex");
EXPECT_EQ(merge_node->input(0), "Conv2D");
EXPECT_EQ(merge_node->input(1), "Conv2D");
auto trans =
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-complex-0-0");
EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
}
TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 4, 2, "VALID");
auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto i = node_map.GetNode("identity_n");
EXPECT_EQ(i->input(0), "vector");
EXPECT_EQ(i->input(1), "Conv2D");
auto trans =
node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
EXPECT_EQ(trans->input(0), "identity_n:1");
auto add_node = node_map.GetNode("add");
EXPECT_EQ(add_node->input(0), "identity_n");
EXPECT_EQ(add_node->input(1),
"LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1");
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -314,6 +314,11 @@ Status CudaSolver::forward_input_or_allocate_scoped_tensor(
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be threadsafe.
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================
template <typename Scalar, typename SolverFnT>
@ -324,6 +329,7 @@ static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
const Scalar* A, int lda,
const Scalar* beta, /* host or device pointer */
const Scalar* B, int ldb, Scalar* C, int ldc) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
reinterpret_cast<const CudaScalar*>(alpha),
@ -355,6 +361,7 @@ static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle,
cublasFillMode_t uplo, int n, Scalar* A, int lda,
int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@ -387,6 +394,7 @@ static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, int* dev_pivots,
int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@ -419,9 +427,6 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
cublasOperation_t trans, int n, int nrhs,
const Scalar* A, int lda, const int* pivots,
Scalar* B, int ldb, int* dev_lapack_info) {
// Note: The cuSolver functions called here appear not to be threadsafe.
// so we put a global lock around it. Since this function only puts a
// kernel on the stream, it is not a big performance hit.
mutex_lock lock(handle_map_mutex);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
@ -449,6 +454,7 @@ static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, Scalar* A, int lda, Scalar* tau,
int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@ -483,6 +489,7 @@ static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
int m, int n, int k, const Scalar* dev_a,
int lda, const Scalar* dev_tau, Scalar* dev_c,
int ldc, int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(
@ -526,6 +533,7 @@ static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
cusolverDnHandle_t cusolver_dn_handle, int m,
int n, int k, Scalar* dev_a, int lda,
const Scalar* dev_tau, int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
@ -606,17 +614,13 @@ static inline Status GesvdImpl(
OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
mutex_lock lock(handle_map_mutex);
/* Get amount of workspace memory required. */
int lwork;
TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
/* Allocate device memory for workspace. */
auto dev_workspace =
cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
// Note: The cuSolver functions called here appear not to be threadsafe.
// so we put a global lock around it. Since this function only puts a
// kernel on the stream, it is not a big performance hit.
mutex_lock lock(handle_map_mutex);
/* Launch the solver kernel. */
TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
CUDAComplex(A), lda, S, CUDAComplex(U),
ldu, CUDAComplex(VT), ldvt,
@ -655,6 +659,7 @@ static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
int lda, int* dev_pivots,
DeviceLapackInfo* dev_lapack_info,
int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@ -689,6 +694,7 @@ static inline Status GetrsBatchedImpl(
const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
const Scalar* const host_b_dev_ptrs[], int ldb,
DeviceLapackInfo* dev_lapack_info, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@ -734,6 +740,7 @@ static inline Status GetriBatchedImpl(
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
@ -776,6 +783,7 @@ static inline Status MatInvBatchedImpl(
cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
DeviceLapackInfo* dev_lapack_info, int batch_size) {
mutex_lock lock(handle_map_mutex);
using CudaScalar = typename CUDAComplexT<Scalar>::type;
ScratchSpace<uint8> dev_a_dev_ptrs =
cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",

View File

@ -493,6 +493,18 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "unique_dataset_op",
srcs = ["unique_dataset_op.cc"],
deps = [
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "dataset_ops",
deps = [
@ -526,6 +538,7 @@ tf_kernel_library(
":take_dataset_op",
":tensor_dataset_op",
":tensor_slice_dataset_op",
":unique_dataset_op",
":zip_dataset_op",
],
)

View File

@ -55,10 +55,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
*output = nullptr;
#define HANDLE_TYPE(T) \
case DataTypeToEnum<T>::value: { \
*output = new Dataset<T>(batch_size, row_shape, input); \
break; \
#define HANDLE_TYPE(T) \
case DataTypeToEnum<T>::value: { \
*output = new Dataset<T>(ctx, batch_size, row_shape, input); \
break; \
}
switch (input->output_dtypes()[0]) {
@ -75,11 +75,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
private:
// TODO(mrry): Push the templated code down to the raw copying routine.
template <class T>
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
Dataset(int64 batch_size, const PartialTensorShape& row_shape,
const DatasetBase* input)
: batch_size_(batch_size), row_shape_(row_shape), input_(input) {
Dataset(OpKernelContext* ctx, int64 batch_size,
const PartialTensorShape& row_shape, const DatasetBase* input)
: GraphDatasetBase(ctx),
batch_size_(batch_size),
row_shape_(row_shape),
input_(input) {
input_->Ref();
output_shapes_.reserve(3);
@ -112,6 +115,25 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
")::Dataset");
}
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_node;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
Node* batch_size_node;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
Node* row_shape_node;
std::vector<int64> row_shape;
row_shape.reserve(
row_shape_.dims()); // not an unknown rank PartialTensorShape
for (int i = 0; i < row_shape_.dims(); i++)
row_shape.emplace_back(row_shape_.dim_size(i));
TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_node, batch_size_node, row_shape_node}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset<T>> {
public:
@ -242,6 +264,20 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(Iterator::SaveParent(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(Iterator::RestoreParent(ctx, reader, input_impl_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);

View File

@ -64,20 +64,23 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
std::move(other_arguments),
&captured_func));
*output =
new Dataset(input, std::move(initial_state), std::move(captured_func),
state_types_, output_types_, output_shapes_);
*output = new Dataset(ctx, input, func_, std::move(initial_state),
std::move(captured_func), state_types_, output_types_,
output_shapes_);
}
private:
class Dataset : public DatasetBase {
class Dataset : public GraphDatasetBase {
public:
Dataset(const DatasetBase* input, std::vector<Tensor> initial_state,
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func, std::vector<Tensor> initial_state,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& state_types,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: input_(input),
: GraphDatasetBase(ctx),
input_(input),
func_(func),
initial_state_(std::move(initial_state)),
captured_func_(std::move(captured_func)),
state_types_(state_types),
@ -103,6 +106,45 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "ScanDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
std::vector<Node*> initial_state_nodes;
initial_state_nodes.reserve(initial_state_.size());
for (const Tensor& t : initial_state_) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
initial_state_nodes.emplace_back(node);
}
std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
other_arguments.emplace_back(node);
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
b->BuildAttrValue(func_, &f);
AttrValue state_types;
b->BuildAttrValue(state_types_, &state_types);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
TF_RETURN_IF_ERROR(
b->AddDataset(this, {{0, input_node}},
{{1, initial_state_nodes}, {2, other_arguments}},
{{"f", f},
{"Tstate", state_types},
{"Targuments", other_arguments_types_attr}},
output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@ -185,6 +227,38 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
return s;
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
if (!state_.empty()) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("state_size"), state_.size()));
for (int idx = 0; idx < state_.size(); idx++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("state[", idx, "]")), state_[idx]));
}
}
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
if (reader->Contains(full_name("state_size"))) {
int64 size;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("state_size"), &size));
state_.resize(size);
for (int idx = 0; idx < size; idx++) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("state[", idx, "]")), &state_[idx]));
}
}
return Status::OK();
}
private:
mutex mu_;
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
@ -192,6 +266,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
const NameAttrList func_;
const std::vector<Tensor> initial_state_;
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector state_types_;

View File

@ -0,0 +1,219 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
class UniqueDatasetOp : public UnaryDatasetOpKernel {
public:
explicit UniqueDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
OP_REQUIRES(ctx, input->output_dtypes().size() == 1,
errors::InvalidArgument("UniqueDataset only supports "
"inputs with a single component."));
DataType input_dtype = input->output_dtypes()[0];
OP_REQUIRES(ctx,
input_dtype == DT_INT32 || input_dtype == DT_INT64 ||
input_dtype == DT_STRING,
errors::InvalidArgument(
"UniqueDataset only supports inputs with a single "
"`tf.int32`, `tf.int64`, or `tf.string` component."));
*output = new Dataset(ctx, input);
}
private:
class Dataset : public GraphDatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input)
: GraphDatasetBase(ctx), input_(input) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIterator(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Unique")}));
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return input_->output_shapes();
}
string DebugString() override {
return strings::StrCat("UniqueDatasetOp::Dataset");
}
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const typename Iterator::Params& params)
: DatasetIterator<Dataset>(params),
input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
bool saw_new_value;
do {
saw_new_value = false;
out_tensors->clear();
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (*end_of_sequence) {
break;
}
DCHECK_EQ(1, out_tensors->size());
saw_new_value = unique_elements_.insert((*out_tensors)[0]).second;
} while (!saw_new_value);
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (input_impl_) {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impl_empty"), ""));
}
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("unique_elements_size"), unique_elements_.size()));
size_t i = 0;
for (const Tensor& t : unique_elements_) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("unique_elements[", i++, "]")), t));
}
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
int64 num_unique_elements;
unique_elements_.clear();
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
&num_unique_elements));
for (int64 i = 0; i < num_unique_elements; ++i) {
Tensor unique_element;
TF_RETURN_IF_ERROR(reader->ReadTensor(
full_name(strings::StrCat("unique_elements[", i, "]")),
&unique_element));
auto insert_result = unique_elements_.insert(unique_element);
if (!insert_result.second) {
return errors::InvalidArgument(
"Checkpoint contained two unique elements with the same "
"value.");
}
}
return Status::OK();
}
private:
struct TensorHash {
size_t operator()(const Tensor& t) const {
if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
return Hash64(t.tensor_data().data(), t.tensor_data().size());
} else {
DCHECK_EQ(DT_STRING, t.dtype());
auto flat_t = t.flat<string>();
uint64 hash = 0;
for (int64 i = 0; i < t.NumElements(); ++i) {
hash = Hash64Combine(hash, Hash64(flat_t(i)));
}
return static_cast<size_t>(hash);
}
}
};
struct TensorKeyEqual {
bool operator()(const Tensor& lhs, const Tensor& rhs) const {
if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
return false;
}
switch (lhs.dtype()) {
#define HANDLE_TYPE(T) \
case T: \
do { \
auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
for (int64 i = 0; i < lhs.NumElements(); ++i) { \
if (lhs_flat(i) != rhs_flat(i)) { \
return false; \
} \
} \
return true; \
} while (0)
HANDLE_TYPE(DT_INT32);
HANDLE_TYPE(DT_INT64);
HANDLE_TYPE(DT_STRING);
default:
LOG(FATAL) << "UniqueDataset unhandled data type: "
<< DataTypeString(lhs.dtype());
}
}
};
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
GUARDED_BY(mu_);
};
const DatasetBase* const input_;
};
};
REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
} // namespace tensorflow

View File

@ -558,6 +558,16 @@ filename: A path on the filesystem where we should cache the dataset. Note: this
will be a directory.
)doc");
REGISTER_OP("UniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Creates a dataset that contains the unique elements of `input_dataset`.
)doc");
REGISTER_OP("TextLineDataset")
.Input("filenames: string")
.Input("compression_type: string")

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/lib/core/errors.h"
@ -327,6 +329,57 @@ Status CurlHttpRequest::SetResultBuffer(std::vector<char>* out_buffer) {
return Status::OK();
}
Status CurlHttpRequest::SetResultBufferDirect(char* buffer, size_t size) {
CHECK(buffer != nullptr);
TF_RETURN_IF_ERROR(CheckInitialized());
TF_RETURN_IF_ERROR(CheckNotSent());
direct_response_ = DirectResponseState{buffer, size, 0};
libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA,
reinterpret_cast<void*>(this));
libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION,
&CurlHttpRequest::WriteCallbackDirect);
return Status::OK();
}
size_t CurlHttpRequest::WriteCallbackDirect(const void* ptr, size_t size,
size_t nmemb, void* userdata) {
CHECK(ptr != nullptr);
auto that = reinterpret_cast<CurlHttpRequest*>(userdata);
DirectResponseState* state = &that->direct_response_;
CHECK(state->buffer_ != nullptr);
CHECK(state->bytes_transferred_ <= state->buffer_size_);
size_t curl_bytes_received = size * nmemb;
size_t user_buffer_bytes_available =
state->buffer_size_ - state->bytes_transferred_;
// The HTTP server may send a response body that is longer than what we
// expected. We must not use CHECK() for this situation, because that would
// imply a code bug (in this client code) where none exists; the violation of
// expectations would have been caused by the server, not the client. So we
// report a log warning, if an HTTP server is misbehaving.
if (curl_bytes_received > user_buffer_bytes_available) {
LOG(WARNING) << "The HTTP response body that we received is longer than we "
"requested or expected. "
<< "Total bytes requested: " << state->buffer_size_
<< " Bytes received (so far) in HTTP response body: "
<< (state->bytes_transferred_ + curl_bytes_received);
}
size_t bytes_to_copy =
std::min<size_t>(curl_bytes_received, user_buffer_bytes_available);
memcpy(&state->buffer_[state->bytes_transferred_], ptr, bytes_to_copy);
state->bytes_transferred_ += bytes_to_copy;
return bytes_to_copy;
}
size_t CurlHttpRequest::GetResultBufferDirectBytesTransferred() {
CHECK(direct_response_.buffer_ != nullptr);
return direct_response_.bytes_transferred_;
}
Status CurlHttpRequest::SetTimeouts(uint32 connection, uint32 inactivity,
uint32 total) {
TF_RETURN_IF_ERROR(CheckInitialized());

View File

@ -103,6 +103,26 @@ class CurlHttpRequest : public HttpRequest {
/// read. Existing content of the vector will be cleared.
Status SetResultBuffer(std::vector<char>* out_buffer) override;
/// \brief Specifies the buffer for receiving the response body, when the
/// caller knows the maximum size of the response body.
///
/// This method allows the caller to receive the response body without an
/// additional intermediate buffer allocation and copy. This method should
/// be called before calling Send(). After Send() has succeeded, the caller
/// should use the GetResultBufferDirectBytesTransferred() method in order
/// to learn how many bytes were transferred.
///
/// Using this method is mutually exclusive with using SetResultBuffer().
Status SetResultBufferDirect(char* buffer, size_t size) override;
/// \brief Returns the number of bytes (of the response body) that were
/// transferred, when using the SetResultBufferDirect() method. The returned
/// value will always be less than or equal to the 'size' parameter that
/// was passed to SetResultBufferDirect(). If the actual HTTP response body
/// was greater than 'size' bytes, then this transfer method will only copy
/// the first 'size' bytes, and the rest will be ignored.
size_t GetResultBufferDirectBytesTransferred() override;
/// \brief Returns the response headers of a completed request.
///
/// If the header is not found, returns an empty string.
@ -127,6 +147,10 @@ class CurlHttpRequest : public HttpRequest {
/// A write callback in the form which can be accepted by libcurl.
static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
void* userdata);
/// Processes response body content received when using SetResultBufferDirect.
static size_t WriteCallbackDirect(const void* ptr, size_t size, size_t nmemb,
void* userdata);
/// A read callback in the form which can be accepted by libcurl.
static size_t ReadCallback(void* ptr, size_t size, size_t nmemb,
FILE* userdata);
@ -150,6 +174,14 @@ class CurlHttpRequest : public HttpRequest {
size_t post_body_read_ = 0;
std::vector<char>* response_buffer_ = nullptr;
struct DirectResponseState {
char* buffer_;
size_t buffer_size_;
size_t bytes_transferred_;
};
DirectResponseState direct_response_ = {};
CURL* curl_ = nullptr;
curl_slist* curl_headers_ = nullptr;
curl_slist* resolve_list_ = nullptr;

View File

@ -288,6 +288,39 @@ TEST(CurlHttpRequestTest, GetRequest) {
EXPECT_EQ(200, http_request.GetResponseCode());
}
TEST(CurlHttpRequestTest, GetRequest_Direct) {
FakeLibCurl libcurl("get response", 200);
CurlHttpRequest http_request(&libcurl);
TF_EXPECT_OK(http_request.Init());
std::vector<char> scratch(100, 0);
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer"));
TF_EXPECT_OK(http_request.SetRange(100, 199));
TF_EXPECT_OK(
http_request.SetResultBufferDirect(scratch.data(), scratch.capacity()));
TF_EXPECT_OK(http_request.Send());
string expected_response = "get response";
size_t response_bytes_transferred =
http_request.GetResultBufferDirectBytesTransferred();
EXPECT_EQ(response_bytes_transferred, expected_response.size());
EXPECT_EQ(
"get response",
string(scratch.begin(), scratch.begin() + response_bytes_transferred));
// Check interactions with libcurl.
EXPECT_TRUE(libcurl.is_initialized_);
EXPECT_EQ("http://www.testuri.com", libcurl.url_);
EXPECT_EQ("100-199", libcurl.range_);
EXPECT_EQ("", libcurl.custom_request_);
EXPECT_EQ(1, libcurl.headers_->size());
EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl.headers_)[0]);
EXPECT_FALSE(libcurl.is_post_);
EXPECT_EQ(200, http_request.GetResponseCode());
}
TEST(CurlHttpRequestTest, GetRequest_Empty) {
FakeLibCurl libcurl("", 200);
CurlHttpRequest http_request(&libcurl);

View File

@ -123,8 +123,12 @@ Status FileBlockCache::MaybeFetch(const Key& key,
case FetchState::CREATED:
block->state = FetchState::FETCHING;
block->mu.unlock(); // Release the lock while making the API call.
status.Update(
block_fetcher_(key.first, key.second, block_size_, &block->data));
block->data.clear();
block->data.resize(block_size_, 0);
size_t bytes_transferred;
status.Update(block_fetcher_(key.first, key.second, block_size_,
block->data.data(), &bytes_transferred));
block->data.resize(bytes_transferred, 0);
block->mu.lock(); // Reacquire the lock immediately afterwards
if (status.ok()) {
downloaded_block = true;
@ -150,15 +154,15 @@ Status FileBlockCache::MaybeFetch(const Key& key,
}
Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
out->clear();
char* buffer, size_t* bytes_transferred) {
*bytes_transferred = 0;
if (n == 0) {
return Status::OK();
}
if (block_size_ == 0 || max_bytes_ == 0) {
// The cache is effectively disabled, so we pass the read through to the
// fetcher without breaking it up into blocks.
return block_fetcher_(filename, offset, n, out);
return block_fetcher_(filename, offset, n, buffer, bytes_transferred);
}
// Calculate the block-aligned start and end of the read.
size_t start = block_size_ * (offset / block_size_);
@ -166,6 +170,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
if (finish < offset + n) {
finish += block_size_;
}
size_t total_bytes_transferred = 0;
// Now iterate through the blocks, reading them one at a time.
for (size_t pos = start; pos < finish; pos += block_size_) {
Key key = std::make_pair(filename, pos);
@ -181,6 +186,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
// The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`.
*bytes_transferred = total_bytes_transferred;
return errors::OutOfRange("EOF at offset ", offset, " in file ", filename,
" at position ", pos, "with data size ",
data.size());
@ -196,13 +202,16 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
end -= (pos + data.size()) - (offset + n);
}
if (begin < end) {
out->insert(out->end(), begin, end);
size_t bytes_to_copy = end - begin;
memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy);
total_bytes_transferred += bytes_to_copy;
}
if (data.size() < block_size_) {
// The block was a partial block and thus signals EOF at its upper bound.
break;
}
}
*bytes_transferred = total_bytes_transferred;
return Status::OK();
}

View File

@ -43,8 +43,9 @@ class FileBlockCache {
/// cache is constructed. The returned Status should be OK as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
typedef std::function<Status(const string&, size_t, size_t,
std::vector<char>*)>
typedef std::function<Status(const string& filename, size_t offset,
size_t buffer_size, char* buffer,
size_t* bytes_transferred)>
BlockFetcher;
FileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness,
@ -83,8 +84,8 @@ class FileBlockCache {
/// placed in `out`.
/// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
/// in `out`).
Status Read(const string& filename, size_t offset, size_t n,
std::vector<char>* out);
Status Read(const string& filename, size_t offset, size_t n, char* buffer,
size_t* bytes_transferred);
/// Remove all cached blocks for `filename`.
void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_);

View File

@ -25,6 +25,18 @@ limitations under the License.
namespace tensorflow {
namespace {
Status ReadCache(FileBlockCache* cache, const string& filename, size_t offset,
size_t n, std::vector<char>* out) {
out->clear();
out->resize(n, 0);
size_t bytes_transferred = 0;
Status status =
cache->Read(filename, offset, n, out->data(), &bytes_transferred);
EXPECT_LE(bytes_transferred, n);
out->resize(bytes_transferred, n);
return status;
}
TEST(FileBlockCacheTest, PassThrough) {
const string want_filename = "foo/bar";
const size_t want_offset = 42;
@ -32,12 +44,13 @@ TEST(FileBlockCacheTest, PassThrough) {
int calls = 0;
auto fetcher = [&calls, want_filename, want_offset, want_n](
const string& got_filename, size_t got_offset,
size_t got_n, std::vector<char>* out) {
size_t got_n, char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(got_filename, want_filename);
EXPECT_EQ(got_offset, want_offset);
EXPECT_EQ(got_n, want_n);
calls++;
out->resize(got_n, 'x');
memset(buffer, 'x', got_n);
*bytes_transferred = got_n;
return Status::OK();
};
// If block_size, max_bytes, or both are zero, the cache is a pass-through.
@ -45,11 +58,11 @@ TEST(FileBlockCacheTest, PassThrough) {
FileBlockCache cache2(0, 1, 0, fetcher);
FileBlockCache cache3(0, 0, 0, fetcher);
std::vector<char> out;
TF_EXPECT_OK(cache1.Read(want_filename, want_offset, want_n, &out));
TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 1);
TF_EXPECT_OK(cache2.Read(want_filename, want_offset, want_n, &out));
TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 2);
TF_EXPECT_OK(cache3.Read(want_filename, want_offset, want_n, &out));
TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 3);
}
@ -63,13 +76,13 @@ TEST(FileBlockCacheTest, BlockAlignment) {
}
// The fetcher just fetches slices of the buffer.
auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
if (offset < buf.size()) {
if (offset + n > buf.size()) {
out->insert(out->end(), buf.begin() + offset, buf.end());
} else {
out->insert(out->end(), buf.begin() + offset, buf.begin() + offset + n);
}
size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
memcpy(buffer, buf.data() + offset, bytes_to_copy);
*bytes_transferred = bytes_to_copy;
} else {
*bytes_transferred = 0;
}
return Status::OK();
};
@ -80,7 +93,7 @@ TEST(FileBlockCacheTest, BlockAlignment) {
for (size_t offset = 0; offset < 10; offset++) {
for (size_t n = block_size - 2; n <= block_size + 2; n++) {
std::vector<char> got;
TF_EXPECT_OK(cache.Read("", offset, n, &got));
TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got));
// Verify the size of the read.
if (offset + n <= size) {
// Expect a full read.
@ -108,24 +121,27 @@ TEST(FileBlockCacheTest, CacheHits) {
const size_t block_size = 16;
std::set<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, std::vector<char>* out) {
size_t n, char* buffer,
size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
calls.insert(offset);
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
const uint32 block_count = 256;
FileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
std::vector<char> out;
out.resize(block_count, 0);
// The cache has space for `block_count` blocks. The loop with i = 0 should
// fill the cache, and the loop with i = 1 should be all cache hits. The
// fetcher checks that it is called once and only once for each offset (to
// fetch the corresponding block).
for (int i = 0; i < 2; i++) {
for (int j = 0; j < block_count; j++) {
TF_EXPECT_OK(cache.Read("", block_size * j, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out));
}
}
}
@ -138,36 +154,39 @@ TEST(FileBlockCacheTest, OutOfRange) {
bool second_block = false;
auto fetcher = [block_size, file_size, &first_block, &second_block](
const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
size_t bytes_to_copy = 0;
if (offset == 0) {
// The first block (16 bytes) of the file.
out->resize(n, 'x');
memset(buffer, 'x', n);
bytes_to_copy = n;
first_block = true;
} else if (offset == block_size) {
// The second block (8 bytes) of the file.
out->resize(file_size - block_size, 'x');
bytes_to_copy = file_size - block_size;
memset(buffer, 'x', bytes_to_copy);
second_block = true;
}
*bytes_transferred = bytes_to_copy;
return Status::OK();
};
FileBlockCache cache(block_size, block_size, 0, fetcher);
std::vector<char> out;
// Reading the first 16 bytes should be fine.
TF_EXPECT_OK(cache.Read("", 0, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out));
EXPECT_TRUE(first_block);
EXPECT_EQ(out.size(), block_size);
// Reading at offset file_size + 4 will read the second block (since the read
// at file_size + 4 = 28 will be aligned to an offset of 16) but will return
// OutOfRange because the offset is past the end of the 24-byte file.
Status status = cache.Read("", file_size + 4, 4, &out);
Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
EXPECT_TRUE(second_block);
EXPECT_EQ(out.size(), 0);
// Reading the second full block will return 8 bytes, from a cache hit.
second_block = false;
TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_FALSE(second_block);
EXPECT_EQ(out.size(), file_size - block_size);
}
@ -178,20 +197,22 @@ TEST(FileBlockCacheTest, Inconsistent) {
const size_t block_size = 16;
// This fetcher returns OK but only fills in one byte for any offset.
auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
out->resize(1, 'x');
EXPECT_GE(n, 1);
memset(buffer, 'x', 1);
*bytes_transferred = 1;
return Status::OK();
};
FileBlockCache cache(block_size, 2 * block_size, 0, fetcher);
std::vector<char> out;
// Read the second block; this should yield an OK status and a single byte.
TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_EQ(out.size(), 1);
// Now read the first block; this should yield an INTERNAL error because we
// had already cached a partial block at a later position.
Status status = cache.Read("", 0, block_size, &out);
Status status = ReadCache(&cache, "", 0, block_size, &out);
EXPECT_EQ(status.code(), error::INTERNAL);
}
@ -199,14 +220,16 @@ TEST(FileBlockCacheTest, LRU) {
const size_t block_size = 16;
std::list<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, std::vector<char>* out) {
size_t n, char* buffer,
size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
if (!calls.empty()) {
EXPECT_EQ(offset, calls.front());
calls.pop_front();
}
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
const uint32 block_count = 2;
@ -216,38 +239,39 @@ TEST(FileBlockCacheTest, LRU) {
// fetcher calls that the cache makes.
calls.push_back(0);
// Cache miss - drains an element from `calls`.
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Cache hit - does not drain an element from `calls`.
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
calls.push_back(block_size);
// Cache miss followed by cache hit.
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
calls.push_back(2 * block_size);
// Cache miss followed by cache hit. Causes eviction of LRU element.
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// LRU element was at offset 0. Cache miss.
calls.push_back(0);
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Element at 2 * block_size is still in cache, and this read should update
// its position in the LRU list so it doesn't get evicted by the next read.
TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// Element at block_size was evicted. Reading this element will also cause
// the LRU element (at 0) to be evicted.
calls.push_back(block_size);
TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
// Element at 0 was evicted again.
calls.push_back(0);
TF_EXPECT_OK(cache.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
}
TEST(FileBlockCacheTest, MaxStaleness) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
calls++;
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
std::vector<char> out;
@ -256,14 +280,14 @@ TEST(FileBlockCacheTest, MaxStaleness) {
// expected.
FileBlockCache cache1(8, 16, 2 /* max staleness */, fetcher, env.get());
// Execute the first read to load the block.
TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Now advance the clock one second at a time and redo the read. The call
// count should advance every 3 seconds (i.e. every time the staleness is
// greater than 2).
for (int i = 1; i <= 10; i++) {
env->SetNowSeconds(i + 1);
TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1 + i / 3);
}
// Now create a cache with max staleness of 0, and verify that it also works
@ -272,27 +296,27 @@ TEST(FileBlockCacheTest, MaxStaleness) {
env->SetNowSeconds(0);
FileBlockCache cache2(8, 16, 0 /* max staleness */, fetcher, env.get());
// Execute the first read to load the block.
TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Advance the clock by a huge amount and verify that the cached block is
// used to satisfy the read.
env->SetNowSeconds(365 * 24 * 60 * 60); // ~1 year, just for fun.
TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
}
TEST(FileBlockCacheTest, RemoveFile) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
calls++;
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
if (offset > 0) {
// The first block is lower case and all subsequent blocks are upper case.
c = toupper(c);
}
out->clear();
out->resize(n, c);
memset(buffer, c, n);
*bytes_transferred = n;
return Status::OK();
};
// This cache has space for 4 blocks; we'll read from two files.
@ -304,41 +328,41 @@ TEST(FileBlockCacheTest, RemoveFile) {
std::vector<char> A(n, 'A');
std::vector<char> B(n, 'B');
// Fill the cache.
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 1);
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 2);
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
EXPECT_EQ(calls, 3);
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// All four blocks should be in the cache now.
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// Remove the blocks from "a".
cache.RemoveFile("a");
// Both blocks from "b" should still be there.
TF_EXPECT_OK(cache.Read("b", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
TF_EXPECT_OK(cache.Read("b", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// The blocks from "a" should not be there.
TF_EXPECT_OK(cache.Read("a", 0, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 5);
TF_EXPECT_OK(cache.Read("a", 8, n, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 6);
}
@ -346,10 +370,10 @@ TEST(FileBlockCacheTest, RemoveFile) {
TEST(FileBlockCacheTest, Prune) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
calls++;
out->clear();
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
std::vector<char> out;
@ -360,20 +384,20 @@ TEST(FileBlockCacheTest, Prune) {
FileBlockCache cache(8, 32, 1 /* max staleness */, fetcher, env.get());
// Read three blocks into the cache, and advance the timestamp by one second
// with each read. Start with a block of "a" at the current timestamp `now`.
TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
// Now load a block of a different file "b" at timestamp `now` + 1
env->SetNowSeconds(now + 1);
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
// Now load a different block of file "a" at timestamp `now` + 1. When the
// first block of "a" expires, this block should also be removed because it
// also belongs to file "a".
TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
// Ensure that all blocks are in the cache (i.e. reads are cache hits).
EXPECT_EQ(cache.CacheSize(), 24);
EXPECT_EQ(calls, 3);
TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake timestamp so that "a" becomes stale via its first block.
env->SetNowSeconds(now + 2);
@ -389,7 +413,7 @@ TEST(FileBlockCacheTest, Prune) {
// There should be one block left in the cache, and it should be the first
// block of "b".
EXPECT_EQ(cache.CacheSize(), 8);
TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake time to `now` + 3, at which point "b" becomes stale.
env->SetNowSeconds(now + 3);
@ -409,14 +433,14 @@ TEST(FileBlockCacheTest, ParallelReads) {
const int callers = 4;
BlockingCounter counter(callers);
auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
counter.DecrementCount();
if (!counter.WaitFor(std::chrono::seconds(10))) {
// This avoids having the test time out, which is harder to debug.
return errors::FailedPrecondition("desired concurrency not reached");
}
out->clear();
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
return Status::OK();
};
const int block_size = 8;
@ -426,7 +450,8 @@ TEST(FileBlockCacheTest, ParallelReads) {
threads.emplace_back(
Env::Default()->StartThread({}, "caller", [&cache, i, block_size]() {
std::vector<char> out;
TF_EXPECT_OK(cache.Read("a", i * block_size, block_size, &out));
TF_EXPECT_OK(
ReadCache(&cache, "a", i * block_size, block_size, &out));
std::vector<char> x(block_size, 'x');
EXPECT_EQ(out, x);
}));
@ -443,11 +468,12 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
Notification notification;
auto fetcher = [&num_requests, &notification, block_size](
const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset, 0);
num_requests++;
out->resize(n, 'x');
memset(buffer, 'x', n);
*bytes_transferred = n;
notification.Notify();
// Wait for other thread to issue read.
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
@ -458,17 +484,16 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
std::unique_ptr<Thread> concurrent(
Env::Default()->StartThread({}, "concurrent", [&cache, block_size] {
std::vector<char> out;
TF_EXPECT_OK(cache.Read("", 0, block_size / 2, &out));
TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
}));
EXPECT_TRUE(WaitForNotificationWithTimeout(&notification, 10000))
<< "Timeout waiting for concurrent thread to start.";
std::vector<char> out;
TF_EXPECT_OK(cache.Read("", block_size / 2, block_size / 2, &out));
TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
EXPECT_EQ(1, num_requests);
}
} // namespace
} // namespace tensorflow

View File

@ -58,6 +58,10 @@ class TestHttpRequest : public HttpRequest {
Status SetResultBuffer(std::vector<char>* out_buffer) override {
return Status::OK();
}
Status SetResultBufferDirect(char* buffer, size_t size) override {
return Status::OK();
}
size_t GetResultBufferDirectBytesTransferred() override { return 0; }
string GetResponseHeader(const string& name) const override { return ""; }
uint64 GetResponseCode() const override { return 0; }

View File

@ -285,11 +285,11 @@ class GcsRandomAccessFile : public RandomAccessFile {
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
*result = StringPiece();
std::vector<char> out;
TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, &out));
std::memcpy(scratch, out.data(), std::min(out.size(), n));
*result = StringPiece(scratch, std::min(out.size(), n));
if (result->size() < n) {
size_t bytes_transferred;
TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch,
&bytes_transferred));
*result = StringPiece(scratch, bytes_transferred);
if (bytes_transferred < n) {
// This is not an error per se. The RandomAccessFile interface expects
// that Read returns OutOfRange if fewer bytes were read than requested.
return errors::OutOfRange("EOF reached, ", result->size(),
@ -721,15 +721,17 @@ std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache(
std::unique_ptr<FileBlockCache> file_block_cache(
new FileBlockCache(block_size, max_bytes, max_staleness,
[this](const string& filename, size_t offset, size_t n,
std::vector<char>* out) {
return LoadBufferFromGCS(filename, offset, n, out);
char* buffer, size_t* bytes_transferred) {
return LoadBufferFromGCS(filename, offset, n, buffer,
bytes_transferred);
}));
return file_block_cache;
}
// A helper function to actually read the data from GCS.
Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
size_t n, std::vector<char>* out) {
size_t n, char* buffer,
size_t* bytes_transferred) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object));
@ -739,21 +741,23 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket,
"/", request->EscapeString(object))));
TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1));
TF_RETURN_IF_ERROR(request->SetResultBuffer(out));
TF_RETURN_IF_ERROR(request->SetResultBufferDirect(buffer, n));
TF_RETURN_IF_ERROR(
request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read));
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object);
size_t bytes_read = request->GetResultBufferDirectBytesTransferred();
*bytes_transferred = bytes_read;
VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
<< offset << " of size: " << out->size();
<< offset << " of size: " << bytes_read;
if (out->size() < block_size()) {
if (bytes_read < block_size()) {
// Check stat cache to see if we encountered an interrupted read.
FileStatistics stat;
if (stat_cache_->Lookup(filename, &stat)) {
if (offset + out->size() < stat.length) {
if (offset + bytes_read < stat.length) {
return errors::Internal(strings::Printf(
"File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset));

View File

@ -177,7 +177,7 @@ class GcsFileSystem : public FileSystem {
/// Loads file contents from GCS for a given filename, offset, and length.
Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n,
std::vector<char>* out);
char* buffer, size_t* bytes_transferred);
std::unique_ptr<AuthProvider> auth_provider_;
std::unique_ptr<HttpRequest::Factory> http_request_factory_;

View File

@ -101,6 +101,20 @@ class HttpRequest {
/// read. Existing content of the vector will be cleared.
virtual Status SetResultBuffer(std::vector<char>* out_buffer) = 0;
/// \brief Specifies the buffer for receiving the response body.
///
/// This method should be used when a caller knows the upper bound of the
/// size of the response data. The caller provides a pre-allocated buffer
/// and its size. After the Send() method is called, the
/// GetResultBufferDirectBytesTransferred() method may be used to learn to the
/// number of bytes that were transferred using this method.
virtual Status SetResultBufferDirect(char* buffer, size_t size) = 0;
/// \brief Returns the number of bytes transferred, when using
/// SetResultBufferDirect(). This method may only be used when using
/// SetResultBufferDirect().
virtual size_t GetResultBufferDirectBytesTransferred() = 0;
/// \brief Returns the response headers of a completed request.
///
/// If the header is not found, returns an empty string.

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
#include <algorithm>
#include <fstream>
#include <string>
#include <vector>
@ -130,12 +131,25 @@ class FakeHttpRequest : public CurlHttpRequest {
buffer_ = buffer;
return Status::OK();
}
Status SetResultBufferDirect(char* buffer, size_t size) override {
direct_result_buffer_ = buffer;
direct_result_buffer_size_ = size;
return Status::OK();
}
size_t GetResultBufferDirectBytesTransferred() override {
return direct_result_bytes_transferred_;
}
Status Send() override {
EXPECT_EQ(expected_request_, actual_request())
<< "Unexpected HTTP request.";
if (buffer_) {
buffer_->insert(buffer_->begin(), response_.c_str(),
response_.c_str() + response_.size());
buffer_->insert(buffer_->begin(), response_.data(),
response_.data() + response_.size());
} else if (direct_result_buffer_ != nullptr) {
size_t bytes_to_copy =
std::min<size_t>(direct_result_buffer_size_, response_.size());
memcpy(direct_result_buffer_, response_.data(), bytes_to_copy);
direct_result_bytes_transferred_ += bytes_to_copy;
}
return response_status_;
}
@ -178,6 +192,9 @@ class FakeHttpRequest : public CurlHttpRequest {
}
std::vector<char>* buffer_ = nullptr;
char* direct_result_buffer_ = nullptr;
size_t direct_result_buffer_size_ = 0;
size_t direct_result_bytes_transferred_ = 0;
string expected_request_;
string actual_uri_;
string actual_request_;

View File

@ -136,15 +136,19 @@ void Env::GetLocalTempDirectories(std::vector<string>* list) {
// Directories, in order of preference. If we find a dir that
// exists, we stop adding other less-preferred dirs
const char* candidates[] = {
// Non-null only during unittest/regtest
getenv("TEST_TMPDIR"),
// Non-null only during unittest/regtest
getenv("TEST_TMPDIR"),
// Explicitly-supplied temp dirs
getenv("TMPDIR"),
getenv("TMP"),
// Explicitly-supplied temp dirs
getenv("TMPDIR"),
getenv("TMP"),
// If all else fails
"/tmp",
#if defined(__ANDROID__)
"/data/local/tmp",
#endif
// If all else fails
"/tmp",
};
for (const char* d : candidates) {

View File

@ -0,0 +1,127 @@
/*
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package internal
/*
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/c_api.h"
*/
import "C"
import (
"errors"
"fmt"
"runtime"
"unsafe"
"github.com/golang/protobuf/proto"
pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
)
// Encapsulates a collection of API definitions.
//
// apiDefMap represents a map from operation name to corresponding
// ApiDef proto (see
// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto
// for ApiDef proto definition).
type apiDefMap struct {
c *C.TF_ApiDefMap
}
// Creates and returns a new apiDefMap instance.
//
// oplist is and OpList proto instance (see
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto
// for OpList proto definition).
func newAPIDefMap(oplist *pb.OpList) (*apiDefMap, error) {
// Create a buffer containing the serialized OpList.
opdefSerialized, err := proto.Marshal(oplist)
if err != nil {
return nil, fmt.Errorf("could not serialize OpDef for %s", oplist.String())
}
data := C.CBytes(opdefSerialized)
defer C.free(data)
opbuf := C.TF_NewBuffer()
defer C.TF_DeleteBuffer(opbuf)
opbuf.data = data
opbuf.length = C.size_t(len(opdefSerialized))
// Create ApiDefMap.
status := C.TF_NewStatus()
defer C.TF_DeleteStatus(status)
capimap := C.TF_NewApiDefMap(opbuf, status)
if C.TF_GetCode(status) != C.TF_OK {
return nil, errors.New(C.GoString(C.TF_Message(status)))
}
apimap := &apiDefMap{capimap}
runtime.SetFinalizer(
apimap,
func(a *apiDefMap) {
C.TF_DeleteApiDefMap(a.c)
})
return apimap, nil
}
// Updates apiDefMap with the overrides specified in `data`.
//
// data - ApiDef text proto.
func (m *apiDefMap) Put(data string) error {
cdata := C.CString(data)
defer C.free(unsafe.Pointer(cdata))
status := C.TF_NewStatus()
defer C.TF_DeleteStatus(status)
C.TF_ApiDefMapPut(m.c, cdata, C.size_t(len(data)), status)
if C.TF_GetCode(status) != C.TF_OK {
return errors.New(C.GoString(C.TF_Message(status)))
}
return nil
}
// Returns ApiDef proto instance for the TensorFlow operation
// named `opname`.
func (m *apiDefMap) Get(opname string) (*pb.ApiDef, error) {
cname := C.CString(opname)
defer C.free(unsafe.Pointer(cname))
status := C.TF_NewStatus()
defer C.TF_DeleteStatus(status)
apidefBuf := C.TF_ApiDefMapGet(
m.c, cname, C.size_t(len(opname)), status)
defer C.TF_DeleteBuffer(apidefBuf)
if C.TF_GetCode(status) != C.TF_OK {
return nil, errors.New(C.GoString(C.TF_Message(status)))
}
if apidefBuf == nil {
return nil, fmt.Errorf("could not find ApiDef for %s", opname)
}
var (
apidef = new(pb.ApiDef)
size = int(apidefBuf.length)
// A []byte backed by C memory.
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
data = (*[1 << 30]byte)(unsafe.Pointer(apidefBuf.data))[:size:size]
err = proto.Unmarshal(data, apidef)
)
if err != nil {
return nil, err
}
return apidef, nil
}

View File

@ -29,12 +29,18 @@ limitations under the License.
// encountered.
package internal
// #include "tensorflow/c/c_api.h"
/*
#include <stdlib.h>
#include "tensorflow/c/c_api.h"
*/
import "C"
import (
"fmt"
"io"
"io/ioutil"
"path"
"reflect"
"strings"
"text/template"
@ -47,15 +53,23 @@ import (
// GenerateFunctionsForRegisteredOps writes a Go source code file to w
// containing functions for each TensorFlow operation registered in the address
// space of the calling process.
func GenerateFunctionsForRegisteredOps(w io.Writer) error {
ops, err := registeredOps()
// apidefDirs should be a contain of directories containing api_def_*.pbtxt
// files to load.
func GenerateFunctionsForRegisteredOps(
w io.Writer, apidefDirs []string) error {
ops, apimap, err := registeredOps()
if err != nil {
return err
}
return generateFunctionsForOps(w, ops)
for _, dir := range apidefDirs {
if err = updateAPIDefs(apimap, dir); err != nil {
return err
}
}
return generateFunctionsForOps(w, ops, apimap)
}
func registeredOps() (*pb.OpList, error) {
func registeredOps() (*pb.OpList, *apiDefMap, error) {
buf := C.TF_GetAllOpList()
defer C.TF_DeleteBuffer(buf)
var (
@ -66,10 +80,31 @@ func registeredOps() (*pb.OpList, error) {
data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size]
err = proto.Unmarshal(data, list)
)
return list, err
if err != nil {
return nil, nil, err
}
apimap, err := newAPIDefMap(list)
return list, apimap, err
}
func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error {
func updateAPIDefs(m *apiDefMap, dir string) error {
files, err := ioutil.ReadDir(dir)
if err != nil {
return err
}
for _, file := range files {
data, err := ioutil.ReadFile(path.Join(dir, file.Name()))
if err != nil {
return fmt.Errorf("failed to read %q: %v", file.Name(), err)
}
if err = m.Put(string(data)); err != nil {
return fmt.Errorf("failed to process %q: %v", file.Name(), err)
}
}
return nil
}
func generateFunctionsForOps(w io.Writer, ops *pb.OpList, apimap *apiDefMap) error {
thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
if err := tmplHeader.Execute(w, thisPackage); err != nil {
return err
@ -83,14 +118,18 @@ func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error {
if blacklist[op.Name] {
continue
}
if err := generateFunctionForOp(w, op); err != nil {
apidef, err := apimap.Get(op.Name)
if err != nil {
return err
}
if err := generateFunctionForOp(w, op, apidef); err != nil {
return err
}
}
return nil
}
func generateFunctionForOp(w io.Writer, op *pb.OpDef) error {
func generateFunctionForOp(w io.Writer, op *pb.OpDef, apidef *pb.ApiDef) error {
if strings.HasPrefix(op.Name, "_") { // Internal operation
return nil
}
@ -112,12 +151,16 @@ func generateFunctionForOp(w io.Writer, op *pb.OpDef) error {
return nil
}
}
if op.Summary == "" {
if apidef.Summary == "" {
// Undocumented operation, perhaps a sign of not being ready to
// export.
return nil
}
return tmplOp.Execute(w, newTmplArgs(op))
tmplArgs, err := newTmplArgs(op, apidef)
if err != nil {
return err
}
return tmplOp.Execute(w, tmplArgs)
}
var (
@ -172,7 +215,7 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in
type {{.Op.Name}}Attr func(optionalAttr)
{{range .OptionalAttrs}}
// {{$.Op.Name}}{{CamelCase .Name}} sets the optional {{.Name}} attribute to value.
// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value.
{{- if .Description}}
//
// value: {{MakeComment .Description}}
@ -180,9 +223,9 @@ type {{.Op.Name}}Attr func(optionalAttr)
// If not specified, defaults to {{StripLeadingColon .DefaultValue}}
{{- if .HasMinimum}}
//
// {{if IsListAttr .}}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
{{- end}}
func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
return func(m optionalAttr) {
m[{{printf "%q" .Name}}] = value
}
@ -192,14 +235,14 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
{{- /* Create a godoc friendly comment. */ -}}
// {{MakeComment .Op.Summary}}
// {{MakeComment .APIDef.Summary}}
{{- with .Op.Deprecation}}
//
// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
{{- end -}}
{{- with .Op.Description}}
{{- with .APIDef.Description}}
//
// {{MakeComment .}}
{{- end -}}
@ -207,11 +250,11 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
{{- if .DescribeArguments}}
//
// Arguments:
{{- range .Op.InputArg}}
// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}}
{{- range .InArgsReordered}}
// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- range .RequiredAttrs}}
// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}}
// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- end -}}
@ -221,12 +264,12 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr
{{- else }}
{{- if .DescribeOutputs}}
//
{{- if ((len .Op.OutputArg) eq 1) }}
// Returns {{range .Op.OutputArg}}{{MakeComment .Description}}{{end}}
{{- if ((len .OutArgs) eq 1) }}
// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}}
{{- else }}
// Returns:
{{- range .Op.OutputArg}}
// {{Identifier .Name}}{{if .Description}}: {{MakeComment .Description}}{{end}}
{{- range .OutArgs}}
// {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}}
{{- end -}}
{{- end -}}
{{- end -}}
@ -247,15 +290,15 @@ func {{.Op.Name}}
*/ -}}
(scope *Scope
{{- range $i, $a := .Op.InputArg}}, {{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}
{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.Name}} {{GoType $a.Type}}{{end -}}
{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}
{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}}
{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
)
{{- /* Construct outputs: len(OpDef.OutputArg) or a *tf.Operation */ -}}
{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}}
{{if .Op.OutputArg -}}
({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}})
{{if .OutArgs -}}
({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}})
{{- else -}}
(o *tf.Operation)
{{- end }} {
@ -263,7 +306,7 @@ func {{.Op.Name}}
return
}
{{if .HasAttrs -}}
attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .Name}},{{end}}}
attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}}
{{if .OptionalAttrs -}}
for _, a := range optional {
a(attrs)
@ -272,16 +315,16 @@ func {{.Op.Name}}
{{end -}}
opspec := tf.OpSpec{
Type: {{printf "%q" .Op.Name}},
{{if .Op.InputArg -}}
{{if .InArgs -}}
Input: []tf.Input{
{{range .Op.InputArg}}{{if IsListArg .}}tf.OutputList({{Identifier .Name}}){{else}}{{Identifier .Name}}{{end}}, {{end}}
{{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}}
},
{{- end}}
{{- if .HasAttrs}}
Attrs: attrs,
{{- end}}
}
{{- if .Op.OutputArg}}
{{- if .OutArgs}}
{{- if .HasListOutput}}
op := scope.AddOperation(opspec)
if scope.Err() != nil {
@ -289,43 +332,105 @@ func {{.Op.Name}}
}
var idx int
var err error
{{- range $i, $a := .Op.OutputArg}}
{{- if IsListArg $a}}
if {{Identifier .Name}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
{{- range $i, $a := .OutArgs}}
{{- if $a.IsListArg}}
if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
return
}
{{- else }}
{{Identifier .Name}} = op.Output(idx)
{{Identifier .RenameTo}} = op.Output(idx)
{{- end }}{{- /* if IsListArg */}}
{{- end }}{{- /* range .Op.OutputArg */}}
return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier .Name}}{{end}}
{{- end }}{{- /* range .OutArgs */}}
return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}}
{{- else }}
op := scope.AddOperation(opspec)
return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
{{- end }}{{- /* if .HasListOutput */}}
{{- else }}
return scope.AddOperation(opspec)
{{- end }}{{- /* if .Op.OutputArg */}}
{{- end }}{{- /* if .OutArgs */}}
}
`))
)
type attrWrapper struct {
op *pb.OpDef_AttrDef
api *pb.ApiDef_Attr
}
func (a *attrWrapper) Name() string { return a.api.Name }
func (a *attrWrapper) RenameTo() string { return a.api.RenameTo }
func (a *attrWrapper) Description() string { return a.api.Description }
func (a *attrWrapper) Type() string { return a.op.Type }
func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) }
func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum }
func (a *attrWrapper) Minimum() int64 { return a.op.Minimum }
func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue }
type argWrapper struct {
op *pb.OpDef_ArgDef
api *pb.ApiDef_Arg
}
func (a *argWrapper) Name() string { return a.api.Name }
func (a *argWrapper) RenameTo() string { return a.api.RenameTo }
func (a *argWrapper) Description() string { return a.api.Description }
func (a *argWrapper) IsListArg() bool { return isListArg(a.op) }
type tmplArgs struct {
Op *pb.OpDef
Op *pb.OpDef
APIDef *pb.ApiDef
// Op.Attr is split into two categories
// (1) Required: These must be specified by the client and are thus
// included in the function signature.
// (2) Optional: These need not be specified (as they have default
// values) and thus do not appear in the function signature.
RequiredAttrs []*pb.OpDef_AttrDef
OptionalAttrs []*pb.OpDef_AttrDef
RequiredAttrs []*attrWrapper
OptionalAttrs []*attrWrapper
InArgs []*argWrapper
// Input arguments ordered based on arg_order field of ApiDef.
InArgsReordered []*argWrapper
OutArgs []*argWrapper
}
func newTmplArgs(op *pb.OpDef) *tmplArgs {
ret := tmplArgs{Op: op}
func newTmplArgs(op *pb.OpDef, apidef *pb.ApiDef) (*tmplArgs, error) {
ret := tmplArgs{Op: op, APIDef: apidef}
// Setup InArgs field
for i, in := range op.InputArg {
argCombined := argWrapper{op: in, api: apidef.InArg[i]}
ret.InArgs = append(ret.InArgs, &argCombined)
}
// Setup OutArgs field
for i, out := range op.OutputArg {
argCombined := argWrapper{op: out, api: apidef.OutArg[i]}
ret.OutArgs = append(ret.OutArgs, &argCombined)
}
// Setup InArgsReordered field
for _, argName := range apidef.ArgOrder {
// Find the argument in op.InputArg
argIndex := -1
for i, in := range op.InputArg {
if in.Name == argName {
argIndex = i
break
}
}
if argIndex == -1 {
return nil, fmt.Errorf(
"couldn't find argument %s in ApiDef for op %s",
argName, op.Name)
}
argCombined := argWrapper{
op: op.InputArg[argIndex], api: apidef.InArg[argIndex]}
ret.InArgsReordered = append(ret.InArgsReordered, &argCombined)
}
if len(op.Attr) == 0 {
return &ret
return &ret, nil
}
// Attributes related to the InputArg's type are inferred automatically
// and are not exposed to the client.
@ -341,28 +446,29 @@ func newTmplArgs(op *pb.OpDef) *tmplArgs {
inferred[in.NumberAttr] = true
}
}
for _, attr := range op.Attr {
for i, attr := range op.Attr {
if inferred[attr.Name] {
continue
}
attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]}
if attr.DefaultValue == nil {
ret.RequiredAttrs = append(ret.RequiredAttrs, attr)
ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined)
} else {
ret.OptionalAttrs = append(ret.OptionalAttrs, attr)
ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined)
}
}
return &ret
return &ret, nil
}
func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
func (a *tmplArgs) DescribeArguments() bool {
for _, arg := range a.Op.InputArg {
if arg.Description != "" {
for _, arg := range a.InArgs {
if arg.Description() != "" {
return true
}
}
for _, attr := range a.RequiredAttrs {
if attr.Description != "" {
if attr.Description() != "" {
return true
}
}
@ -370,16 +476,16 @@ func (a *tmplArgs) DescribeArguments() bool {
}
func (a *tmplArgs) DescribeOutputs() bool {
for _, arg := range a.Op.OutputArg {
if arg.Description != "" {
for _, arg := range a.OutArgs {
if arg.Description() != "" {
return true
}
}
return false
}
func (a *tmplArgs) HasListOutput() bool {
for _, arg := range a.Op.OutputArg {
if isListArg(arg) {
for _, arg := range a.OutArgs {
if arg.IsListArg() {
return true
}
}

View File

@ -25,19 +25,44 @@ import (
pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
)
// Creates an ApiDef based on opdef and applies overrides
// from apidefText (ApiDef text proto).
func GetAPIDef(t *testing.T, opdef *pb.OpDef, apidefText string) *pb.ApiDef {
opdefList := &pb.OpList{Op: []*pb.OpDef{opdef}}
apimap, err := newAPIDefMap(opdefList)
if err != nil {
t.Fatal(err)
}
err = apimap.Put(apidefText)
if err != nil {
t.Fatal(err)
}
apidef, err := apimap.Get(opdef.Name)
if err != nil {
t.Fatal(err)
}
return apidef
}
func TestGenerateOp(t *testing.T) {
// TestGenerateOp validates the generated source code for an op.
// The OpDef for the test cases are simplified forms of real ops.
testdata := []struct {
tag string
opdef string
apidef string
wanted string
}{
{
tag: "NoOp",
opdef: `
name: "NoOp"
`,
apidef: `
op: <
graph_op_name: "NoOp"
summary: "No. Op."
>
`,
wanted: `
// No. Op.
@ -80,8 +105,13 @@ attr: <
>
>
>
`,
apidef: `
op: <
graph_op_name: "Add"
summary: "Returns x + y element-wise."
description: "Blah blah",
>
`,
wanted: `
// Returns x + y element-wise.
@ -122,7 +152,12 @@ attr: <
name: "DstT"
type: "type"
>
`,
apidef: `
op: <
graph_op_name: "Cast"
summary: "Cast x of type SrcT to y of DstT."
>
`,
wanted: `
// Cast x of type SrcT to y of DstT.
@ -149,12 +184,10 @@ func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) {
name: "DecodeJpeg"
input_arg: <
name: "contents"
description: "0-D. The JPEG-encoded image."
type: DT_STRING
>
output_arg: <
name: "image"
description: "3-D with shape [height, width, channels]"
type: DT_UINT8
>
attr: <
@ -163,7 +196,6 @@ attr: <
default_value: <
i: 0
>
description: "Number of color channels for the decoded image."
>
attr: <
name: "fancy_upscaling"
@ -171,7 +203,6 @@ attr: <
default_value: <
b: true
>
description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
>
attr: <
name: "acceptable_fraction"
@ -179,10 +210,34 @@ attr: <
default_value: <
f: 1
>
>
`,
apidef: `
op: <
graph_op_name: "DecodeJpeg"
in_arg: <
name: "contents"
description: "0-D. The JPEG-encoded image."
>
out_arg: <
name: "image"
description: "3-D with shape [height, width, channels]"
>
attr: <
name: "channels"
description: "Number of color channels for the decoded image."
>
attr: <
name: "fancy_upscaling"
description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
>
attr: <
name: "acceptable_fraction"
description: "The minimum required fraction of lines before a truncated\ninput is accepted."
>
summary: "Decode a JPEG-encoded image to a uint8 tensor."
description: "Norna dorna fjord\nkajorna\nhahaha"
>
`,
wanted: `
// DecodeJpegAttr is an optional argument to DecodeJpeg.
@ -270,7 +325,12 @@ attr: <
name: "T"
type: "type"
>
`,
apidef: `
op: <
graph_op_name: "TwoOutputs"
summary: "Op that produces multiple outputs"
>
`,
wanted: `
// Op that produces multiple outputs
@ -326,8 +386,13 @@ attr: <
>
>
>
`,
apidef: `
op: <
graph_op_name: "ShapeN"
summary: "Returns shape of tensors."
description: "Some description here."
>
`,
wanted: `
// ShapeNAttr is an optional argument to ShapeN.
@ -371,6 +436,102 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
}
return output
}
`,
},
{
tag: "ApiDefOverrides",
opdef: `
name: "TestOp"
input_arg: <
name: "a"
type: DT_STRING
>
input_arg: <
name: "b"
type: DT_STRING
>
output_arg: <
name: "c"
type: DT_UINT8
>
attr: <
name: "d"
type: "int"
default_value: <
i: 0
>
>
`,
apidef: `
op: <
graph_op_name: "TestOp"
in_arg: <
name: "a"
rename_to: "aa"
description: "Description for aa."
>
in_arg: <
name: "b"
rename_to: "bb"
description: "Description for bb."
>
arg_order: "b"
arg_order: "a"
out_arg: <
name: "c"
rename_to: "cc"
description: "Description for cc."
>
attr: <
name: "d"
rename_to: "dd"
description: "Description for dd."
>
summary: "Summary for TestOp."
description: "Description for TestOp."
>
`,
wanted: `
// TestOpAttr is an optional argument to TestOp.
type TestOpAttr func(optionalAttr)
// TestOpDd sets the optional dd attribute to value.
//
// value: Description for dd.
// If not specified, defaults to 0
func TestOpDd(value int64) TestOpAttr {
return func(m optionalAttr) {
m["d"] = value
}
}
// Summary for TestOp.
//
// Description for TestOp.
//
// Arguments:
// bb: Description for bb.
// aa: Description for aa.
//
// Returns Description for cc.
func TestOp(scope *Scope, bb tf.Output, aa tf.Output, optional ...TestOpAttr) (cc tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "TestOp",
Input: []tf.Input{
aa, bb,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
`,
},
}
@ -378,11 +539,13 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
for _, test := range testdata {
t.Run(test.tag, func(t *testing.T) {
var opdef pb.OpDef
var apidef *pb.ApiDef
var buf bytes.Buffer
if err := proto.UnmarshalText(test.opdef, &opdef); err != nil {
t.Fatal(err)
}
if err := generateFunctionForOp(&buf, &opdef); err != nil {
apidef = GetAPIDef(t, &opdef, test.apidef)
if err := generateFunctionForOp(&buf, &opdef, apidef); err != nil {
t.Fatal(err)
}
got, err := format.Source(buf.Bytes())

View File

@ -27,15 +27,17 @@ import (
"log"
"os"
"path/filepath"
"strings"
"github.com/tensorflow/tensorflow/tensorflow/go/genop/internal"
)
func main() {
var (
filename = flag.String("outfile", "", "File to write generated source code to.")
header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. Can be empty")
buf bytes.Buffer
filename = flag.String("outfile", "", "File to write generated source code to.")
header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. Can be empty")
apiDefDirs = flag.String("api_def_dirs", "", "Comma-separated directories containing api_def_*.pbtxt files.")
buf bytes.Buffer
)
flag.Parse()
if *filename == "" {
@ -51,7 +53,13 @@ func main() {
}
os.MkdirAll(filepath.Dir(*filename), 0755)
if err := internal.GenerateFunctionsForRegisteredOps(&buf); err != nil {
apiDefDirsList := []string{}
if len(*apiDefDirs) > 0 {
apiDefDirsList = strings.Split(*apiDefDirs, ",")
}
if err := internal.GenerateFunctionsForRegisteredOps(
&buf, apiDefDirsList); err != nil {
log.Fatal(err)
}
formatted, err := format.Source(buf.Bytes())

View File

@ -52,6 +52,12 @@ py_library(
]),
)
py_library(
name = "common",
srcs = ["lib/common.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "debug_graphs",
srcs = ["lib/debug_graphs.py"],
@ -117,6 +123,7 @@ py_library(
srcs = ["lib/source_remote.py"],
srcs_version = "PY2AND3",
deps = [
":common",
":debug_service_pb2_grpc",
"//tensorflow/core/debug:debug_service_proto_py",
"//tensorflow/python/profiler:tfprof_logger",
@ -193,6 +200,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":command_parser",
":common",
":debugger_cli_common",
":tensor_format",
"//tensorflow/python:framework_for_generated_wrappers",
@ -334,7 +342,11 @@ py_library(
name = "grpc_wrapper",
srcs = ["wrappers/grpc_wrapper.py"],
srcs_version = "PY2AND3",
deps = [":framework"],
deps = [
":common",
":framework",
":source_remote",
],
)
py_library(
@ -345,6 +357,7 @@ py_library(
":analyzer_cli",
":cli_shared",
":command_parser",
":common",
":debug_data",
":debugger_cli_common",
":framework",
@ -439,6 +452,20 @@ py_binary(
],
)
py_test(
name = "common_test",
size = "small",
srcs = ["lib/common_test.py"],
srcs_version = "PY2AND3",
deps = [
":common",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "debug_graphs_test",
size = "small",

View File

@ -55,7 +55,8 @@ def no_rewrite_session_config():
rewriter_config = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)

View File

@ -25,6 +25,7 @@ import six
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import tensor_format
from tensorflow.python.debug.lib import common
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
@ -214,51 +215,6 @@ def error(msg):
RL("ERROR: " + msg, COLOR_RED)])
def get_graph_element_name(elem):
"""Obtain the name or string representation of a graph element.
If the graph element has the attribute "name", return name. Otherwise, return
a __str__ representation of the graph element. Certain graph elements, such as
`SparseTensor`s, do not have the attribute "name".
Args:
elem: The graph element in question.
Returns:
If the attribute 'name' is available, return the name. Otherwise, return
str(fetch).
"""
return elem.name if hasattr(elem, "name") else str(elem)
def _get_fetch_names(fetches):
"""Get a flattened list of the names in run() call fetches.
Args:
fetches: Fetches of the `Session.run()` call. It maybe a Tensor, an
Operation or a Variable. It may also be nested lists, tuples or
dicts. See doc of `Session.run()` for more details.
Returns:
(list of str) A flattened list of fetch names from `fetches`.
"""
lines = []
if isinstance(fetches, (list, tuple)):
for fetch in fetches:
lines.extend(_get_fetch_names(fetch))
elif isinstance(fetches, dict):
for key in fetches:
lines.extend(_get_fetch_names(fetches[key]))
else:
# This ought to be a Tensor, an Operation or a Variable, for which the name
# attribute should be available. (Bottom-out condition of the recursion.)
lines.append(get_graph_element_name(fetches))
return lines
def _recommend_command(command, description, indent=2, create_link=False):
"""Generate a RichTextLines object that describes a recommended command.
@ -327,14 +283,14 @@ def get_run_start_intro(run_call_count,
(RichTextLines) Formatted intro message about the `Session.run()` call.
"""
fetch_lines = _get_fetch_names(fetches)
fetch_lines = common.get_flattened_names(fetches)
if not feed_dict:
feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")]
else:
feed_dict_lines = []
for feed_key in feed_dict:
feed_key_name = get_graph_element_name(feed_key)
feed_key_name = common.get_graph_element_name(feed_key)
feed_dict_line = debugger_cli_common.RichLine(" ")
feed_dict_line += debugger_cli_common.RichLine(
feed_key_name,
@ -446,10 +402,10 @@ def get_run_short_description(run_call_count,
description = "run #%d: " % run_call_count
if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
description += "1 fetch (%s); " % get_graph_element_name(fetches)
description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
else:
# Could be (nested) list, tuple, dict or namedtuple.
num_fetches = len(_get_fetch_names(fetches))
num_fetches = len(common.get_flattened_names(fetches))
if num_fetches > 1:
description += "%d fetches; " % num_fetches
else:

View File

@ -0,0 +1,87 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common values and methods for TensorFlow Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
GRPC_URL_PREFIX = "grpc://"
# A key for a Session.run() call.
RunKey = collections.namedtuple("RunKey", ["feed_names", "fetch_names"])
def get_graph_element_name(elem):
"""Obtain the name or string representation of a graph element.
If the graph element has the attribute "name", return name. Otherwise, return
a __str__ representation of the graph element. Certain graph elements, such as
`SparseTensor`s, do not have the attribute "name".
Args:
elem: The graph element in question.
Returns:
If the attribute 'name' is available, return the name. Otherwise, return
str(fetch).
"""
return elem.name if hasattr(elem, "name") else str(elem)
def get_flattened_names(feeds_or_fetches):
"""Get a flattened list of the names in run() call feeds or fetches.
Args:
feeds_or_fetches: Feeds or fetches of the `Session.run()` call. It maybe
a Tensor, an Operation or a Variable. It may also be nested lists, tuples
or dicts. See doc of `Session.run()` for more details.
Returns:
(list of str) A flattened list of fetch names from `feeds_or_fetches`.
"""
lines = []
if isinstance(feeds_or_fetches, (list, tuple)):
for item in feeds_or_fetches:
lines.extend(get_flattened_names(item))
elif isinstance(feeds_or_fetches, dict):
for key in feeds_or_fetches:
lines.extend(get_flattened_names(feeds_or_fetches[key]))
else:
# This ought to be a Tensor, an Operation or a Variable, for which the name
# attribute should be available. (Bottom-out condition of the recursion.)
lines.append(get_graph_element_name(feeds_or_fetches))
return lines
def get_run_key(feed_dict, fetches):
"""Summarize the names of feeds and fetches as a RunKey JSON string.
Args:
feed_dict: The feed_dict given to the `Session.run()` call.
fetches: The fetches from the `Session.run()` call.
Returns:
A JSON Array consisting of two items. They first items is a flattened
Array of the names of the feeds. The second item is a flattened Array of
the names of the fetches.
"""
return json.dumps(RunKey(get_flattened_names(feed_dict),
get_flattened_names(fetches)))

View File

@ -0,0 +1,59 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for common values and methods of TensorFlow Debugger."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from tensorflow.python.debug.lib import common
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class CommonTest(test_util.TensorFlowTestCase):
def testOnFeedOneFetch(self):
a = constant_op.constant(10.0, name="a")
b = constant_op.constant(20.0, name="b")
run_key = common.get_run_key({"a": a}, [b])
loaded = json.loads(run_key)
self.assertItemsEqual(["a:0"], loaded[0])
self.assertItemsEqual(["b:0"], loaded[1])
def testGetRunKeyFlat(self):
a = constant_op.constant(10.0, name="a")
b = constant_op.constant(20.0, name="b")
run_key = common.get_run_key({"a": a}, [a, b])
loaded = json.loads(run_key)
self.assertItemsEqual(["a:0"], loaded[0])
self.assertItemsEqual(["a:0", "b:0"], loaded[1])
def testGetRunKeyNestedFetches(self):
a = constant_op.constant(10.0, name="a")
b = constant_op.constant(20.0, name="b")
c = constant_op.constant(30.0, name="c")
d = constant_op.constant(30.0, name="d")
run_key = common.get_run_key(
{}, {"set1": [a, b], "set2": {"c": c, "d": d}})
loaded = json.loads(run_key)
self.assertItemsEqual([], loaded[0])
self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], loaded[1])
if __name__ == "__main__":
googletest.main()

View File

@ -310,7 +310,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
op_log_proto.id_to_string)
raise ValueError(
"Op '%s' does not exist in the tracebacks received by the debug "
"server.")
"server." % op_name)
def query_origin_stack(self):
"""Query the stack of the origin of the execution call.
@ -348,6 +348,9 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
Raises:
ValueError: If no source file is found at the given file_path.
"""
if not self._source_files:
raise ValueError(
"This debug server has not received any source file contents yet.")
for source_file_proto in self._source_files.source_files:
if source_file_proto.file_path == file_path:
return source_file_proto.lines[lineno - 1]

View File

@ -248,7 +248,7 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
self.assertEqual(
14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
def testTensorBoardDebugHooWorks(self):
def testTensorBoardDebugHookWorks(self):
u = variables.Variable(2.1, name="u")
v = variables.Variable(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
@ -261,8 +261,37 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
["localhost:%d" % self._server_port])
sess = monitored_session._HookedSession(sess, [grpc_debug_hook])
# Activate watch point on some a tensor before calling sess.run().
self._server.request_watch("u/read", 0, "DebugIdentity")
self.assertAllClose(42.0, sess.run(w))
# self.assertAllClose(42.0, sess.run(w))
dump = debug_data.DebugDumpDir(self._dump_root)
self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
# Check that the server has received the stack trace.
self.assertTrue(self._server.query_op_traceback("u"))
self.assertTrue(self._server.query_op_traceback("u/read"))
self.assertTrue(self._server.query_op_traceback("v"))
self.assertTrue(self._server.query_op_traceback("v/read"))
self.assertTrue(self._server.query_op_traceback("w"))
# Check that the server has received the python file content.
# Query an arbitrary line to make sure that is the case.
with open(__file__, "rt") as this_source_file:
first_line = this_source_file.readline().strip()
self.assertEqual(
first_line, self._server.query_source_file_line(__file__, 1))
self._server.clear_data()
# Call sess.run() again, and verify that this time the traceback and source
# code is not sent, because the graph version is not newer.
self.assertAllClose(42.0, sess.run(w))
with self.assertRaises(ValueError):
self._server.query_op_traceback("delta_1")
with self.assertRaises(ValueError):
self._server.query_source_file_line(__file__, 1)
def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
hooks.GrpcDebugHook(["grpc://foo:42424"])
hooks.GrpcDebugHook(["foo:42424"])
@ -748,6 +777,28 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
# to disable the breakpoint at delta:0:DebugIdentity.
self.assertSetEqual(set(), self._server_1.breakpoints)
if i == 0:
# Check that the server has received the stack trace.
self.assertTrue(self._server_1.query_op_traceback("delta_1"))
self.assertTrue(self._server_1.query_op_traceback("delta_2"))
self.assertTrue(self._server_1.query_op_traceback("inc_v_1"))
self.assertTrue(self._server_1.query_op_traceback("inc_v_2"))
# Check that the server has received the python file content.
# Query an arbitrary line to make sure that is the case.
with open(__file__, "rt") as this_source_file:
first_line = this_source_file.readline().strip()
self.assertEqual(
first_line, self._server_1.query_source_file_line(__file__, 1))
else:
# In later Session.run() calls, the traceback shouldn't have been sent
# because it is already sent in the 1st call. So calling
# query_op_traceback() should lead to an exception, because the test
# debug server clears the data at the beginning of every iteration.
with self.assertRaises(ValueError):
self._server_1.query_op_traceback("delta_1")
with self.assertRaises(ValueError):
self._server_1.query_source_file_line(__file__, 1)
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
with session.Session() as sess:
v = variables.Variable(50.0, name="v")

View File

@ -24,6 +24,7 @@ import grpc
from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.debug.lib import source_utils
from tensorflow.python.platform import gfile
@ -130,6 +131,11 @@ def _send_call_tracebacks(destinations,
"""
if not isinstance(destinations, list):
destinations = [destinations]
# Strip grpc:// prefix, if any is present.
destinations = [
dest[len(common.GRPC_URL_PREFIX):]
if dest.startswith(common.GRPC_URL_PREFIX) else dest
for dest in destinations]
call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
if is_eager_execution

View File

@ -17,15 +17,55 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import traceback
# Google-internal import(s).
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.wrappers import framework
def publish_traceback(debug_server_urls,
graph,
feed_dict,
fetches,
old_graph_version):
"""Publish traceback and source code if graph version is new.
`graph.version` is compared with `old_graph_version`. If the former is higher
(i.e., newer), the graph traceback and the associated source code is sent to
the debug server at the specified gRPC URLs.
Args:
debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of
debug server URLs.
graph: A Python `tf.Graph` object.
feed_dict: Feed dictionary given to the `Session.run()` call.
fetches: Fetches from the `Session.run()` call.
old_graph_version: Old graph version to compare to.
Returns:
If `graph.version > old_graph_version`, the new graph version as an `int`.
Else, the `old_graph_version` is returned.
"""
# TODO(cais): Consider moving this back to the top, after grpc becomes a
# pip dependency of tensorflow or tf_debug.
# pylint:disable=g-import-not-at-top
from tensorflow.python.debug.lib import source_remote
# pylint:enable=g-import-not-at-top
if graph.version > old_graph_version:
run_key = common.get_run_key(feed_dict, fetches)
source_remote.send_graph_tracebacks(
debug_server_urls, run_key, traceback.extract_stack(), graph,
send_source=True)
return graph.version
else:
return old_graph_version
class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
"""Debug Session wrapper that send debug data to gRPC stream(s)."""
_GRPC_URL_PREFIX = "grpc://"
def __init__(self,
sess,
grpc_debug_server_addresses,
@ -94,8 +134,8 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
return self._grpc_debug_server_urls
def _normalize_grpc_url(self, address):
return (self._GRPC_URL_PREFIX + address
if not address.startswith(self._GRPC_URL_PREFIX) else address)
return (common.GRPC_URL_PREFIX + address
if not address.startswith(common.GRPC_URL_PREFIX) else address)
class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
@ -126,3 +166,25 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
watch_fn=_gated_grpc_watch_fn,
thread_name_filter=thread_name_filter,
log_usage=log_usage)
# Keeps track of the latest version of Python graph object that has been
# sent to the debug servers.
self._sent_graph_version = -sys.maxint
def run(self,
fetches,
feed_dict=None,
options=None,
run_metadata=None,
callable_runner=None,
callable_runner_args=None):
self._sent_graph_version = publish_traceback(
self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
self._sent_graph_version)
return super(TensorBoardDebugWrapperSession, self).run(
fetches,
feed_dict=feed_dict,
options=options,
run_metadata=run_metadata,
callable_runner=callable_runner,
callable_runner_args=callable_runner_args)

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import stepper
@ -331,3 +333,13 @@ class TensorBoardDebugHook(GrpcDebugHook):
watch_fn=_gated_grpc_watch_fn,
thread_name_filter=thread_name_filter,
log_usage=log_usage)
self._grpc_debug_server_addresses = grpc_debug_server_addresses
self._sent_graph_version = -sys.maxint
def before_run(self, run_context):
self._sent_graph_version = grpc_wrapper.publish_traceback(
self._grpc_debug_server_addresses, run_context.session.graph,
run_context.original_args.feed_dict, run_context.original_args.fetches,
self._sent_graph_version)
return super(TensorBoardDebugHook, self).before_run(run_context)

View File

@ -31,6 +31,7 @@ from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.cli import profile_analyzer_cli
from tensorflow.python.debug.cli import stepper_cli
from tensorflow.python.debug.cli import ui_factory
from tensorflow.python.debug.lib import common
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import framework
@ -464,7 +465,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
feed_key = None
feed_value = None
for key in self._feed_dict:
key_name = cli_shared.get_graph_element_name(key)
key_name = common.get_graph_element_name(key)
if key_name == tensor_name:
feed_key = key_name
feed_value = self._feed_dict[key]
@ -561,7 +562,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
list(self._tensor_filters.keys()))
if self._feed_dict:
# Register tab completion for feed_dict keys.
feed_keys = [cli_shared.get_graph_element_name(key)
feed_keys = [common.get_graph_element_name(key)
for key in self._feed_dict.keys()]
curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys)

View File

@ -495,14 +495,13 @@ class GraphModeFunction(object):
def _get_defun_inputs(args):
"""Maps the inputs args to graph inputs."""
ret = []
for a in args:
flat_args = nest.flatten(args)
for a in flat_args:
if isinstance(a, ops.Tensor):
ret.append(graph_placeholder(a.dtype, a.shape))
elif type(a) in (tuple, list):
ret.append(_get_defun_inputs(a))
else:
ret.append(a)
return tuple(ret) if type(args) is tuple else ret
return nest.pack_sequence_as(args, ret)
def _defun_internal(name, func, args, kwds):
@ -582,8 +581,10 @@ def _cache_key(x):
return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
if isinstance(x, np.ndarray):
return ("array", x.shape, tuple(x.reshape(-1)))
if type(x) in (list, tuple):
if isinstance(x, (list, tuple)):
return tuple([_cache_key(a) for a in x])
if isinstance(x, dict):
return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items())
return x

View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
@ -57,6 +59,20 @@ class FunctionTest(test.TestCase):
out = sq(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedInputsGraphMode(self):
matmul = function.defun(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
@function.defun
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out = a_times_b(pair({'a': t}, {'b': t}))
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testGraphModeWithGradients(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v')
@ -83,6 +99,22 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedInputsDefunOpGraphMode(self):
matmul = function.defun(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
inputs = pair({'a': t}, {'b': t})
sq_op = function.make_defun_op(a_times_b, inputs)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(inputs)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedOutputDefunOpGraphMode(self):
matmul = function.defun(math_ops.matmul)

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include "tensorflow/python/lib/core/bfloat16.h"
#include "tensorflow/core/framework/numeric_types.h"
@ -477,8 +479,61 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) {
return true;
}
template <typename InType, typename OutType, typename Functor>
void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
void* data) {
const char* i0 = args[0];
const char* i1 = args[1];
char* o = args[2];
for (npy_intp k = 0; k < *dimensions; k++) {
InType x = *reinterpret_cast<const InType*>(i0);
InType y = *reinterpret_cast<const InType*>(i1);
*reinterpret_cast<OutType*>(o) = Functor()(x, y);
i0 += steps[0];
i1 += steps[1];
o += steps[2];
}
}
template <typename Functor>
void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps,
void* data) {
BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data);
}
struct Bfloat16EqFunctor {
npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
};
struct Bfloat16NeFunctor {
npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
};
struct Bfloat16LtFunctor {
npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
};
struct Bfloat16GtFunctor {
npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
};
struct Bfloat16LeFunctor {
npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
};
struct Bfloat16GeFunctor {
npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
};
// Initializes the module.
bool Initialize() {
// It's critical to import umath to avoid crash in open source build.
import_umath1(false);
Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
if (!numpy_str) {
return false;
}
Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get()));
if (!numpy) {
return false;
}
// We hit a mysterious crash if we haven't initialized numpy before this:
PyBfloat16_Type.tp_base = &PyGenericArrType_Type;
@ -536,6 +591,57 @@ bool Initialize() {
/*cast_is_safe=*/true)) {
return false;
}
// Register ufuncs
auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn,
const std::array<int, 3>& types) {
Safe_PyObjectPtr ufunc_obj =
make_safe(PyObject_GetAttrString(numpy.get(), name));
if (!ufunc_obj) {
return false;
}
PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get());
if (types.size() != ufunc->nargs) {
PyErr_Format(PyExc_AssertionError,
"ufunc %s takes %d arguments, loop takes %lu", name,
ufunc->nargs, types.size());
return false;
}
if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn,
const_cast<int*>(types.data()),
nullptr) < 0) {
return false;
}
return true;
};
// Comparisons
const std::array<int, 3> compare_types = {npy_bfloat16_, npy_bfloat16_,
NPY_BOOL};
if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>,
compare_types)) {
return false;
}
if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>,
compare_types)) {
return false;
}
if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) {
return false;
}
if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>,
compare_types)) {
return false;
}
if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>,
compare_types)) {
return false;
}
if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>,
compare_types)) {
return false;
}
return true;
}

View File

@ -172,6 +172,24 @@ class Bfloat16NumPyTest(test.TestCase):
self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x))
self.assertAllEqual(x, x)
self.assertAllClose(x, x)
self.assertTrue((x == x).all())
def testComparisons(self):
x = np.array([401408, 7, -32], dtype=np.float32)
bx = x.astype(bfloat16)
y = np.array([82432, 7, 0], dtype=np.float32)
by = y.astype(bfloat16)
self.assertAllEqual(x == y, bx == by)
self.assertAllEqual(x != y, bx != by)
self.assertAllEqual(x < y, bx < by)
self.assertAllEqual(x > y, bx > by)
self.assertAllEqual(x <= y, bx <= by)
self.assertAllEqual(x >= y, bx >= by)
def testEqual2(self):
a = np.array([401408], bfloat16)
b = np.array([82432], bfloat16)
self.assertFalse(a.__eq__(b))
def testCasts(self):
for dtype in [

View File

@ -32,6 +32,7 @@ limitations under the License.
#include <Python.h>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
namespace tensorflow {

View File

@ -25,6 +25,7 @@ py_library(
":main_op",
":signature_constants",
":signature_def_utils",
":simple_save",
":tag_constants",
":utils",
"//tensorflow/python:util",
@ -89,6 +90,23 @@ py_library(
],
)
py_library(
name = "simple_save",
srcs = [
"simple_save.py",
],
srcs_version = "PY2AND3",
deps = [
":builder",
":signature_constants",
":signature_def_utils",
":tag_constants",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:lib",
"//tensorflow/python:util",
],
)
py_library(
name = "main_op",
srcs = [
@ -198,6 +216,22 @@ py_test(
],
)
py_test(
name = "simple_save_test",
size = "small",
srcs = ["simple_save_test.py"],
srcs_version = "PY2AND3",
deps = [
":loader",
":signature_constants",
":simple_save",
":tag_constants",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:variables",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

View File

@ -30,6 +30,9 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
# pylint: enable=unused-import
# pylint: disable=wildcard-import
from tensorflow.python.saved_model.simple_save import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@ -41,6 +44,7 @@ _allowed_symbols = [
"main_op",
"signature_constants",
"signature_def_utils",
"simple_save",
"tag_constants",
"utils",
]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel utility functions."""
"""SavedModel simple save functionality."""
from __future__ import absolute_import
from __future__ import division
@ -39,7 +39,7 @@ def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None):
to configure a SavedModel, this method has a few practical implications:
- It will be treated as a graph for inference / serving (i.e. uses the tag
`tag_constants.SERVING`)
- The saved model will load in TensorFlow Serving and supports the
- The SavedModel will load in TensorFlow Serving and supports the
[Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto).
To use the Classify, Regress, or MultiInference APIs, please
use either

Some files were not shown because too many files have changed in this diff Show More