diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index ef7eb5a4d16..f258bcd9568 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -42,6 +42,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", ], }), ) @@ -73,6 +74,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 9b570470284..3aff4f91789 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/framework/op_gen_lib.h" #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -2618,4 +2619,54 @@ void TF_SessionPRun(TF_Session* session, const char* handle, output_values, target_names, nullptr, status); } +TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { + tensorflow::OpList op_list; + if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { + status->status = InvalidArgument("Unparseable OpList"); + return nullptr; + } + status->status = Status::OK(); + return new TF_ApiDefMap(op_list); +} + +void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } + +void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, + size_t text_len, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported in Android."); +#else + mutex_lock l(api_def_map->lock); + if (api_def_map->update_docs_called) { + status->status = FailedPrecondition( + "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " + "called."); + return; + } + string api_def_text(text, text_len); + status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); +#endif // __ANDROID__ +} + +TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, + size_t name_len, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported in Android."); + return nullptr; +#else + mutex_lock l(api_def_map->lock); + if (!api_def_map->update_docs_called) { + api_def_map->api_def_map.UpdateDocs(); + api_def_map->update_docs_called = true; + } + string name_str(name, name_len); + const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); + + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(*api_def, ret); + return ret; +#endif // __ANDROID__ +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index de9527f86d1..6f1c0606c11 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1518,6 +1518,49 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // in this address space. TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +// TF_ApiDefMap encapsulates a collection of API definitions for an operation. 
+// +// This object maps the name of a TensorFlow operation to a description of the +// API to generate for it, as defined by the ApiDef protocol buffer ( +// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) +// +// The ApiDef messages are typically used to generate convenience wrapper +// functions for TensorFlow operations in various language bindings. +typedef struct TF_ApiDefMap TF_ApiDefMap; + +// Creates a new TF_ApiDefMap instance. +// +// Params: +// op_list_buffer - TF_Buffer instance containing serialized OpList +// protocol buffer. (See +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto +// for the OpList proto definition). +// status - Set to OK on success and an appropriate error on failure. +TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, + TF_Status* status); + +// Deallocates a TF_ApiDefMap. +TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); + +// Add ApiDefs to the map. +// +// `text` corresponds to a text representation of an ApiDefs protocol message. +// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). +// +// The provided ApiDefs will be merged with existing ones in the map, with +// precedence given to the newly added version in case of conflicts with +// previous calls to TF_ApiDefMapPut. +TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, + const char* text, size_t text_len, + TF_Status* status); + +// Returns a serialized ApiDef protocol buffer for the TensorFlow operation +// named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, + const char* name, + size_t name_len, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 6df77a7f9ba..f8edc90a9f9 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -24,6 +24,9 @@ limitations under the License. #include #include +#ifndef __ANDROID__ +#include "tensorflow/core/framework/op_gen_lib.h" +#endif #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -158,6 +161,22 @@ struct TF_Function { tensorflow::FunctionDef fdef; }; +struct TF_ApiDefMap { + explicit TF_ApiDefMap(const tensorflow::OpList& op_list) + : +#ifndef __ANDROID__ + api_def_map(op_list), +#endif + update_docs_called(false) { + } + +#ifndef __ANDROID__ + tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); +#endif + bool update_docs_called GUARDED_BY(lock); + tensorflow::mutex lock; +}; + namespace tensorflow { class TensorCApi { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 4e89b4fc439..df697e16d3d 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
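For orientation, a minimal sketch of how the ApiDefMap C API added above is meant to be driven by a client, mirroring the tests added to c_api_test.cc below. The op name "Identity" and the overwrite text are illustrative only, and error handling is abbreviated.

#include <string.h>
#include "tensorflow/c/c_api.h"

void ApiDefMapExample() {
  TF_Status* status = TF_NewStatus();
  // OpList proto describing every op registered in this binary.
  TF_Buffer* op_list = TF_GetAllOpList();
  TF_ApiDefMap* api_def_map = TF_NewApiDefMap(op_list, status);
  if (TF_GetCode(status) == TF_OK) {
    // Overwrites must be loaded before the first TF_ApiDefMapGet call.
    const char* overwrite =
        "op: < graph_op_name: \"Identity\" summary: \"Example summary\" >";
    TF_ApiDefMapPut(api_def_map, overwrite, strlen(overwrite), status);
    // On success, `api_def` holds a serialized ApiDef proto for the op.
    TF_Buffer* api_def = TF_ApiDefMapGet(api_def_map, "Identity", 8, status);
    TF_DeleteBuffer(api_def);
  }
  TF_DeleteApiDefMap(api_def_map);
  TF_DeleteBuffer(op_list);
  TF_DeleteStatus(status);
}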
#include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/node_def.pb_text.h" @@ -2027,6 +2028,77 @@ TEST_F(CApiAttributesTest, Errors) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); } +TEST(TestApiDef, TestCreateApiDef) { + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + +TEST(TestApiDef, TestCreateApiDefWithOverwrites) { + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string api_def_overwrites = R"(op: < + graph_op_name: "TestCApi" + summary: "New summary" +> +)"; + status = TF_NewStatus(); + TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(), + api_def_overwrites.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ("New summary", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + #undef EXPECT_TF_META } // namespace diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index d533758e360..53df884e7ca 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -33,7 +33,7 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", ], - }), + }) + ["//tensorflow/core:gpu_runtime"], ) tf_cuda_library( @@ -55,6 +55,10 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], + tags = [ + "guitar", + "multi_gpu", + ], deps = [ ":c_api", "//tensorflow/core:lib", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 706c89536db..beffa191d16 100644 --- 
a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" +#include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -167,18 +168,6 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, if (is_same_device) { return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd); } - const bool src_cpu = IsCPU(srcd); - if (src_cpu == dst_cpu) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat( - "TFE_TensorHandleCopyToDevice requires either the source " - "TFE_TensorHandle be on or the destination device be on CPU " - "or be the same (they are ", - DeviceName(srcd), " and ", DeviceName(dstd), " in this call)") - .c_str()); - return nullptr; - } tensorflow::Tensor* src = &(h->t); if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) { TF_SetStatus( @@ -189,26 +178,19 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, .c_str()); return nullptr; } - if (src_cpu) { - tensorflow::Tensor dst( - dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(), - src->shape()); - if (src->shape().num_elements() == 0) { - return new TFE_TensorHandle(dst, dstd); - } - tensorflow::Notification n; - dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice( - src, dstd, &dst, [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); - n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd) - : nullptr; + tensorflow::Tensor dst(dstd->GetAllocator(tensorflow::AllocatorAttributes()), + src->dtype(), src->shape()); + if (src->shape().num_elements() == 0) { + return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd); + } + tensorflow::DeviceContext* src_device_context = nullptr; + if (!IsCPU(srcd)) { + src_device_context = srcd->tensorflow_gpu_device_info()->default_context; + } + tensorflow::DeviceContext* dst_device_context = nullptr; + if (!dst_cpu) { + dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; } - CHECK(dst_cpu); - tensorflow::Tensor dst(src->dtype(), src->shape()); - tensorflow::Notification n; // TODO(ashankar): The Sync() call below may be more aggressive than // necessary. It is based on knowledge of implementation details - that // GPU devices are implemented using 3 streams - one for host->device copies, @@ -217,16 +199,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, // but more than necessary (since it waits for operations that might have // nothing to do with this tensor to complete). status->status = srcd->Sync(); - if (!status->status.ok()) return nullptr; - srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU( - src, "IGNORE_MY_TENSOR_NAME", srcd, &dst, - [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); + tensorflow::Notification n; + tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, + srcd, dstd, tensorflow::AllocatorAttributes(), + tensorflow::AllocatorAttributes(), src, &dst, + [status, &n](const tensorflow::Status& s) { + status->status = s; + n.Notify(); + }); n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) ? 
new TFE_TensorHandle(dst, nullptr) - : nullptr; + return (TF_GetCode(status) == TF_OK) + ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd) + : nullptr; } TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3fe0b7efa11..c5ec0cfc31d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -216,6 +216,64 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + const char* kCPUDevice = "CPU:0"; + if (num_devices < 3) { + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + return; + } + const string gpu_1_name(TF_DeviceListName(devices, 1, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + const string gpu_2_name(TF_DeviceListName(devices, 2, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + + TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice( + hdevice, ctx, gpu_2_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_DeleteTensorHandle(hdevice); + // Copy back to CPU + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_DeleteTensorHandle(hdevice2); + + // Ensure that the contents are the same! 
+ TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get()); + TFE_DeleteTensorHandle(hcopy); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)); + EXPECT_EQ( + 0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t))); + TF_DeleteTensor(tcopy); + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + TEST(CAPI, TensorHandleSilentCopy) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 026a1bf879d..2374620f583 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -248,6 +248,7 @@ cc_library( "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 407b7dcbfb4..0de163d3a8f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" @@ -102,16 +103,26 @@ void MarkGuaranteedConstants( // A node/slot pair. // TODO(phawkins): is there a common definition of this? struct NodeSlot { - NodeSlot() : node(nullptr), slot(-1) {} - NodeSlot(const Node* node, int slot) : node(node), slot(slot) {} + NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} + NodeSlot(const Node* node, int slot) + : node(node), slot(slot), dtype(DT_INVALID) {} + NodeSlot(const Node* node, int slot, DataType dtype) + : node(node), slot(slot), dtype(dtype) {} const Node* node; int slot; + // Optional: used to record the destination type of a source NodeSlot in case + // the source output is a Ref type that is cast to a Tensor at the + // destination. + DataType dtype; + bool operator==(const NodeSlot& other) const { - return node == other.node && slot == other.slot; + return node == other.node && slot == other.slot && dtype == other.dtype; } + // Leave dtype out of the hash since there are never two NodeSlots with the + // same node and slot and different dtypes. struct Hasher { uint64 operator()(NodeSlot const& s) const { return Hash64Combine(std::hash()(s.node), @@ -130,11 +141,19 @@ struct NodeSlot { // everything to use it. 
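Returning briefly to the TFE_TensorHandleCopyToDevice rewrite above: with the source-must-be-on-CPU restriction removed, copies between two non-CPU devices are routed through CopyTensor::ViaDMA, as the new TensorHandleCopyBetweenTwoGPUDevices test exercises. The sketch below is illustrative only; the short device names and the assumption that at least two GPUs are visible are mine, and error handling is condensed.

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

// Copies a CPU handle to GPU:0 and then directly to GPU:1; before this patch
// the second, device-to-device copy failed with InvalidArgument.
TFE_TensorHandle* CopyAcrossGpus(TFE_Context* ctx, TFE_TensorHandle* h_cpu,
                                 TF_Status* status) {
  TFE_TensorHandle* h_gpu0 =
      TFE_TensorHandleCopyToDevice(h_cpu, ctx, "GPU:0", status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_TensorHandle* h_gpu1 =
      TFE_TensorHandleCopyToDevice(h_gpu0, ctx, "GPU:1", status);
  TFE_DeleteTensorHandle(h_gpu0);
  return TF_GetCode(status) == TF_OK ? h_gpu1 : nullptr;
}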
static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; +static const char* const kSendToHostOp = "_XlaSendToHost"; +static const char* const kRecvFromHostOp = "_XlaRecvFromHost"; +static const char* const kSendFromHostOp = "_XlaSendFromHost"; +static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; class Encapsulator { public: - Encapsulator(string group_attribute, Graph const* graph_in) - : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {} + Encapsulator(string group_attribute, string outside_compilation_attribute, + Graph const* graph_in) + : group_attribute_(std::move(group_attribute)), + outside_compilation_attribute_( + std::move(outside_compilation_attribute)), + graph_in_(graph_in) {} // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. @@ -156,7 +175,47 @@ class Encapsulator { private: // A subgraph of the input, all marked with a common 'group_attribute' - // value. + // value. A subgraph may contain multiple `outside_compilation' clusters. + // + // In the following simple example, A, B, ..., E are nodes in the original + // graph. The group attributes and outside_compilation attributes g and oc are + // each shown as either 0 or empty. + // + // A --> B --> C --> D --> E + // g: g:0 g:0 g:0 g: + // oc: oc: oc:0 oc: oc: + // + // The example is rewritten to two graphs; one on the host and one to be + // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving + // input from the compiled cluster, and SFH is a SendFromHost node sending + // input back to the compiled cluster. Dotted edges are control edges. A + // 'sequencing' node S is inserted, and both RAH and SFH are connected via S + // to E (and in general all nodes that depend on nodes in the compiled + // cluster) to ensure that they are not pruned. + // + // A --> Call --> E + // ^ + // . + // ........> S + // .... ^ + // .. . + // RAH --> C --> SFH + // + // The compiled cluster is as follows. STH is a SendToHost node which is the + // source of a channel to the RAH node above. RFH is a RecvFromHost node which + // is the destination of a channel from the SFH node above. There is a control + // edge that ensures RFH follows STH, which is used in shape inference to + // ensure that the shapes on the STH host channel are known before the RFH + // channel is compiled. + // + // Arg --> B --> STH ..> RFH --> D --> Retval + // + // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most + // one RAH and SFH in each compiled cluster. This design is preferred over + // adding separate Arg/Retval nodes for each transmitted value because it + // simplifies the host code that would like to limit communication between + // host and device and, e.g., raise only one interrupt per channel rather than + // one per transmitted value. class Subgraph { public: // Creates a graph to build the subgraph in, if it doesn't already exist, @@ -181,6 +240,12 @@ class Encapsulator { const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out); + // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. + Status AddOutsideCompilationHostIONodes( + const string& subgraph_name, + const std::unordered_map& node_images, + Graph* graph_out); + // Returns the Node that inputs to the function should be wired up to. 
Node* GetCallNodeForInputs() const; @@ -193,6 +258,24 @@ class Encapsulator { // Returns the index of the result that the src of edge should connect to. int GetResultIndexForEdge(const Edge* edge) const; + // Returns the RecvAtHost node for an outside_compilation subgraph. + Node* GetRecvAtHostNode( + const string& outside_compilation_subgraph_name) const; + + // Returns the output slot for the RecvAtHost node that corresponds to the + // source of edge in an outside_compilation subgraph. + int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name, + const Edge* edge) const; + + // Returns the SendFromHost node for an outside_compilation subgraph. + Node* GetSendFromHostNode( + const string& outside_compilation_subgraph_name) const; + + // Returns the input slot for the SendFromHost node that corresponds to the + // destination of edge in an outside_compilation subgraph. + int GetSendFromHostSlot(const string& outside_compilation_subgraph_name, + const Edge* edge) const; + // Creates an _Arg node for the src node of edge, and add its index to // args_by_src_, if none exists yet. Also adds its index to args_by_dst_, // and adds the edge within the subgraph from the _Arg node to the image of @@ -208,13 +291,102 @@ class Encapsulator { const Edge* edge, const std::unordered_map& node_images); + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Creates an entry for the src node of edge in the list of + // inputs for the outside_compilation subgraph, if none exists yet. + void RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge); + + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Creates an entry for the src node of edge in the list of + // outputs by src for the outside_compilation subgraph, if none exists + // yet. Creates an entry for the dst node of edge in the list of outputs by + // dst for the outside_compilation subgraph. + void RecordOutsideCompilationOutputOrControl( + const string& outside_compilation_id, const Edge* edge); + + // Adds the SendToHost nodes for each outside_compilation subgraph once the + // edges have all been recorded via RecordOutsideCompilationInputOrControl. + Status AddSendsToOutsideCompilation( + const std::unordered_map& node_images); + + // Adds the RecvFromHost nodes for each outside_compilation subgraph once + // the edges have all been recorded via + // RecordOutsideCompilationOutputOrControl. + Status AddRecvsFromOutsideCompilation( + const std::unordered_map& node_images); + + // Creates the sequencer node if it doesn't exist, adding it to graph_out. + Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); + + // If there is a sequencer node, adds a control edge from the sequencer to + // all the downstream nodes of call_node_outputs. + void ConnectSequencerToOutputs(Graph* graph_out); + private: + struct OutsideCompilationSubgraph { + // Map from source (producer node/slot) tensors in the original graph to + // input index (slot number in the SendToHost/RecvAtHost nodes that will + // be created) for the outside_compilation subgraph. + std::unordered_map inputs; + + // Set of nodes in the original graph that are the source of control edges + // that cross from the containing compiled subgraph into the + // outside_compilation subgraph. 
These are recorded by + RecordOutsideCompilationInputOrControl while walking all the subgraph + edges, and lifted control edges within the subgraph are added by + AddSendsToOutsideCompilation once the _SendToHost node has been + created. The matching control edge from _RecvAtHost to the + destination is added by CopyEdgeToOutputGraph. + std::unordered_set control_inputs; + + // Maps from source (producer node/slot) and destination (consumer + // node/slot) tensors in the original graph to output index (slot number + // in the SendFromHost/RecvFromHost nodes that will be created) for the + // outside_compilation subgraph. + std::unordered_map outputs_by_src; + std::unordered_map outputs_by_dst; + + // Set of nodes in the original graph that are the destination of control + // edges that cross from the outside_compilation subgraph into the + // containing compiled subgraph. These are recorded by + // RecordOutsideCompilationOutputOrControl while walking all the subgraph + // edges, and lifted control edges within the subgraph are added by + // AddRecvsFromOutsideCompilation once the _RecvFromHost node has been + // created. The matching control edge from the source to the + // _SendFromHost node is added by CopyEdgeToOutputGraph. + std::unordered_set control_outputs; + + // _SendToHost node in the subgraph. Not owned. + Node* send_to_host = nullptr; + + // _RecvAtHost node in the output graph. Not owned. + Node* recv_at_host = nullptr; + + // _SendFromHost node in the output graph. Not owned. + Node* send_from_host = nullptr; + }; + // Builds a ParallelCheck op that compares the output of the original // subgraph with the encapsulated subgraph. Status BuildParallelCheckOp( const std::unordered_map& node_images, Graph* graph_out); + // Builds a _RecvAtHost node producing all the inputs of an + // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. + Status AddRecvAtHostNode(const string& subgraph_name, + const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); + + // Builds a _SendFromHost node consuming all the outputs of an + // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. + Status AddSendFromHostNode( + const std::unordered_map& node_images, + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); + // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are // returned by _Retval nodes. @@ -247,19 +419,33 @@ class Encapsulator { // Map from source tensor in the input graph to result #. std::unordered_map results_; + + // The outside_compilation clusters in this subgraph. + std::unordered_map + outside_compilation_subgraphs_; + + // NoOp node in the output graph that is sequenced after the call node and + // used to prevent host-side outside_compilation sends and recvs from being + // pruned. + Node* sequencer_ = nullptr; }; - // Returns the key attribute associated with a node in attr. Sets attr to the - // empty string if the attribute is not found. - Status GetFunctionNameAttr(const Node* node, string* attr) const; + // Returns the key attribute and outside_compilation attribute associated + // with a node in attr, and outside_compilation_attr, respectively. Sets + // either result to the empty string if the respective attribute is not + // found.
Returns error status if there is an outside_compilation attribute + // and no key attribute, + Status GetFunctionNameAttr(Node const* node, string* attr, + string* outside_compilation_attr) const; - // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to subgraphs - // for data edges that cross subgraph boundaries. + // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to + // subgraphs for data edges that cross subgraph boundaries. Status CopySubgraphEdges( const std::unordered_map& node_images, std::vector>* src_arg_pairs); - // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes. + // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes, + // or nodes marked outside_compilation. Status CopySubgraphNodes(std::unordered_map* node_images); // Copies all nodes that aren't in a compiled subgraph to the output graph. @@ -272,38 +458,54 @@ class Encapsulator { const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out); + // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all + // outside_compilation subgraphs. + Status AddOutsideCompilationHostIONodes( + const std::unordered_map& node_images, + Graph* graph_out); + // Finds the image of an edge source in the output graph. If the edge crosses // a subgraph boundary it is the output of a call node, otherwise it is a node // in the output graph. Status FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& dst_func_id, + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, const Node* original_src_node, Node** src_image); // Finds an edge source slot in the output graph. If the edge crosses a - // subgraph boundary it is a slot on the output of a call node, otherwise it - // is a slot on a node in the output graph. + // subgraph boundary it is a slot on the output of a call node or a + // _RecvAtHost node, otherwise it is a slot on a node in the output graph. int FindOutputSlotOfEdgeSrc(const string& src_func_id, - const string& dst_func_id, const Edge* edge); + const string& src_outside_compilation_id, + const string& dst_func_id, + const string& dst_outside_compilation_id, + const Edge* edge); // Finds the image of an edge destination in the output graph. If the edge - // crosses a subgraph boundary it is the input of a call node, otherwise it is - // a node in the output graph. + // crosses a subgraph boundary it is the input of a call node or a + // _SendFromHost node, otherwise it is a node in the output graph. Status FindOutputImageOfEdgeDst( - const string& src_func_id, const string& dst_func_id, + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, const Node* original_dst_node, Node** dst_image); // Finds an edge destination slot in the output graph. If the edge crosses a - // subgraph boundary it is a slot on the input of a call node, otherwise it is - // a slot on a node in the output graph. + // subgraph boundary it is a slot on the input of a call node or a + // _SendFromHost node, otherwise it is a slot on a node in the output graph. 
int FindOutputSlotOfEdgeDst(const string& src_func_id, - const string& dst_func_id, const Edge* edge); + const string& src_outside_compilation_id, + const string& dst_func_id, + const string& dst_outside_compilation_id, + const Edge* edge); // Copies a single edge to the output graph. The edge is either entirely // within the output graph, or crosses into or out of a compiled subgraph. Status CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, const string& dst_func_id, + const Edge* edge, const string& src_func_id, + const string& src_outside_compilation_id, const string& dst_func_id, + const string& dst_outside_compilation_id, const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out, std::unordered_set, NodeSlot::PairHasher>* @@ -339,6 +541,30 @@ int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { return results_.at(NodeSlot(edge->src(), edge->src_output())); } +Node* Encapsulator::Subgraph::GetRecvAtHostNode( + const string& outside_compilation_subgraph_name) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .recv_at_host; +} + +int Encapsulator::Subgraph::GetRecvAtHostSlot( + const string& outside_compilation_subgraph_name, const Edge* edge) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .inputs.at(NodeSlot(edge->src(), edge->src_output())); +} + +Node* Encapsulator::Subgraph::GetSendFromHostNode( + const string& outside_compilation_subgraph_name) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .send_from_host; +} + +int Encapsulator::Subgraph::GetSendFromHostSlot( + const string& outside_compilation_subgraph_name, const Edge* edge) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); +} + Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { if (!graph_) { graph_.reset(new Graph(graph_in->op_registry())); @@ -367,19 +593,10 @@ Status Encapsulator::Subgraph::RecordArg( args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); int arg_index = iter->second; if (inserted) { - // Look at the type of the destination not the source, since Ref output - // Tensors can be automatically cast to non-Ref Tensors at the destination. 
- DataType dtype = edge->dst()->input_type(edge->dst_input()); - - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported as args: tensor ", - src_node->name(), ":", src_slot); - } - NodeDef arg_def; NodeDefBuilder builder( strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); Status s = builder.Finalize(&arg_def); @@ -411,17 +628,10 @@ Status Encapsulator::Subgraph::RecordResult( results_.emplace(NodeSlot(src_node, src_slot), results_.size()); int ret_index = iter->second; if (inserted) { - DataType dtype = src_node->output_type(src_slot); - - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported as results: tensor ", - src_node->name(), ":", src_slot); - } - NodeDef ret_def; NodeDefBuilder builder( strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); builder.Input(src_image->name(), src_slot, dtype); @@ -435,6 +645,184 @@ Status Encapsulator::Subgraph::RecordResult( return Status::OK(); } +void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge) { + auto iter = outside_compilation_subgraphs_ + .emplace(outside_compilation_id, OutsideCompilationSubgraph()) + .first; + OutsideCompilationSubgraph& outside_subgraph = iter->second; + if (edge->IsControlEdge()) { + outside_subgraph.control_inputs.insert(edge->src()); + } else { + int input_index = outside_subgraph.inputs.size(); + outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()), + input_index); + } +} + +void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( + const string& outside_compilation_id, const Edge* edge) { + auto subgraph_iter = + outside_compilation_subgraphs_ + .emplace(outside_compilation_id, OutsideCompilationSubgraph()) + .first; + OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second; + if (edge->IsControlEdge()) { + outside_subgraph.control_outputs.insert(edge->dst()); + } else { + DataType dtype = edge->dst()->input_type(edge->dst_input()); + auto output_iter = + outside_subgraph.outputs_by_src + .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), + outside_subgraph.outputs_by_src.size()) + .first; + int output_index = output_iter->second; + outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + output_index; + } +} + +Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( + const std::unordered_map& node_images) { + for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { + const string& oc_subgraph_name = oc_subgraph_iter.first; + OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { + // Build a _SendToHost node sending all the args of the appropriate + // types. 
+ std::vector dtypes(oc_subgraph.inputs.size(), DT_INVALID); + std::vector inputs(oc_subgraph.inputs.size()); + + for (const auto& input_src : oc_subgraph.inputs) { + const Node* src_node = input_src.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = input_src.first.slot; + int input_index = input_src.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[input_index] = dtype; + inputs[input_index].Reset(src_image->name(), src_slot, dtype); + } + + NodeDef send_def; + NodeDefBuilder builder( + strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"), + kSendToHostOp); + builder.Attr("dtypes", dtypes); + builder.Input(inputs); + Status s = builder.Finalize(&send_def); + if (!s.ok()) return s; + + oc_subgraph.send_to_host = graph_->AddNode(send_def, &s); + if (!s.ok()) return s; + + // Connect the _SendToHost node to its producers in the subgraph. + for (auto& input_src : oc_subgraph.inputs) { + const Node* src_node = input_src.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = input_src.first.slot; + int input_index = input_src.second; + graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host, + input_index); + } + + // Connect the _SendToHost node to its control edge producers in the + // subgraph. + for (const auto& src_node : oc_subgraph.control_inputs) { + Node* src_image = node_images.at(src_node); + graph_->AddControlEdge(src_image, oc_subgraph.send_to_host); + } + } + } + + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation( + const std::unordered_map& node_images) { + for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { + const string& oc_subgraph_name = oc_subgraph_iter.first; + OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; + if (!oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + // Build a _RecvFromHost node producing all the outputs of the appropriate + // types. + std::vector dtypes(oc_subgraph.outputs_by_src.size(), + DT_INVALID); + + for (const auto& output : oc_subgraph.outputs_by_src) { + DataType dtype = output.first.dtype; + int output_index = output.second; + dtypes[output_index] = dtype; + } + + NodeDef recv_def; + NodeDefBuilder builder( + strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"), + kRecvFromHostOp); + builder.Attr("dtypes", dtypes); + Status s = builder.Finalize(&recv_def); + if (!s.ok()) return s; + + Node* recv = graph_->AddNode(recv_def, &s); + if (!s.ok()) return s; + + // Connect the consumers in the subgraph to the _RecvFromHost node. + for (const auto& output : oc_subgraph.outputs_by_dst) { + const Node* dst_node = output.first.node; + Node* dst_image = node_images.at(dst_node); + int dst_slot = output.first.slot; + int output_index = output.second; + + graph_->AddEdge(recv, output_index, dst_image, dst_slot); + } + + // Connect the control edge consumers in the subgraph to the _RecvFromHost + // node. + for (const auto& dst_node : oc_subgraph.control_outputs) { + Node* dst_image = node_images.at(dst_node); + graph_->AddControlEdge(recv, dst_image); + } + + // Add a control edge in the subgraph so that the _SendToHost node, if + // any, is compiled before the _RecvFromHost node. 
+ if (oc_subgraph.send_to_host != nullptr) { + graph_->AddControlEdge(oc_subgraph.send_to_host, recv); + } + } + } + + return Status::OK(); +} + +Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, + Graph* graph_out) { + if (sequencer_ == nullptr) { + NodeDef seq_def; + NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), + "NoOp"); + Status s = builder.Finalize(&seq_def); + if (!s.ok()) return s; + + sequencer_ = graph_out->AddNode(seq_def, &s); + if (!s.ok()) return s; + sequencer_->set_assigned_device_name(device_); + } + return Status::OK(); +} + +void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { + if (sequencer_ != nullptr) { + std::unordered_set output_dependencies; + for (Node* node : call_node_outputs_->out_nodes()) { + output_dependencies.insert(node); + } + for (Node* node : output_dependencies) { + graph_out->AddControlEdge(sequencer_, node); + } + } +} + Status Encapsulator::Subgraph::BuildFunctionDef( const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { @@ -574,24 +962,144 @@ Status Encapsulator::Subgraph::AddFunctionCallNode( return Status::OK(); } -Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { +Status Encapsulator::Subgraph::AddRecvAtHostNode( + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + std::vector dtypes(oc_subgraph->inputs.size(), DT_INVALID); + + for (const auto& input : oc_subgraph->inputs) { + const Node* src_node = input.first.node; + int src_slot = input.first.slot; + int input_index = input.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[input_index] = dtype; + } + + NodeDef recv_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), + kRecvAtHostOp); + builder.Attr("dtypes", dtypes); + Status s = builder.Finalize(&recv_def); + if (!s.ok()) return s; + + oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); + if (!s.ok()) return s; + oc_subgraph->recv_at_host->set_assigned_device_name(device_); + + // Add a control dependency forcing the RecvAtHost to run before the subgraph + // completes. This has no effect on execution order but prevents the + // RecvAtHost being pruned. 
+ TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); + graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); + + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddSendFromHostNode( + const std::unordered_map& node_images, + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + std::vector dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); + std::vector inputs( + oc_subgraph->outputs_by_src.size()); + + for (const auto& output : oc_subgraph->outputs_by_src) { + const Node* src_node = output.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = output.first.slot; + int output_index = output.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[output_index] = dtype; + inputs[output_index].Reset(src_image->name(), src_slot, dtype); + } + + NodeDef send_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), + kSendFromHostOp); + builder.Attr("dtypes", dtypes); + builder.Input(inputs); + Status s = builder.Finalize(&send_def); + if (!s.ok()) return s; + + oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); + if (!s.ok()) return s; + oc_subgraph->send_from_host->set_assigned_device_name(device_); + + // Add a control dependency forcing the SendFromHost to run before the + // subgraph completes. This has no effect on execution order but prevents the + // SendFromHost being pruned. + TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); + graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); + + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( + const string& subgraph_name, + const std::unordered_map& node_images, + Graph* graph_out) { + for (auto& outside_compilation_subgraph_entry : + outside_compilation_subgraphs_) { + const string& oc_name = outside_compilation_subgraph_entry.first; + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraph_entry.second; + + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { + TF_RETURN_IF_ERROR( + AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out)); + } + + if (!oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name, + oc_name, &oc_subgraph, graph_out)); + } + } + return Status::OK(); +} + +Status Encapsulator::GetFunctionNameAttr( + Node const* node, string* attr, string* outside_compilation_attr) const { Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); if (s.code() == error::Code::NOT_FOUND) { // Return empty attr if there's no group_attribute. attr->clear(); - return Status::OK(); + } else { + TF_RETURN_IF_ERROR(s); } - return s; + bool has_group_attr = s.ok(); + s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, + outside_compilation_attr); + if (s.code() == error::Code::NOT_FOUND) { + // Return empty attr if there's no outside_compilation attribute.
+ outside_compilation_attr->clear(); + } else { + TF_RETURN_IF_ERROR(s); + if (!has_group_attr) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } + } + return Status::OK(); } -bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } +bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { + return !func_id.empty() && outside_compilation_id.empty(); +} Status Encapsulator::CopySubgraphNodes( std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); - if (!IsInSubgraph(func_id)) continue; + string outside_compilation_id; + TF_RETURN_IF_ERROR( + GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); + if (!IsInSubgraph(func_id, outside_compilation_id)) continue; Subgraph& subgraph = subgraphs_[func_id]; Node* image = subgraph.MakeNodeImage(graph_in_, node); @@ -606,14 +1114,19 @@ Status Encapsulator::CopySubgraphEdges( std::vector>* src_arg_pairs) { for (const Edge* edge : graph_in_->edges()) { string src_func_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); + string src_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, + &src_outside_compilation_id)); string dst_func_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); + string dst_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, + &dst_outside_compilation_id)); Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); // Copy edges that are local to a subgraph. - if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) && + if (IsInSubgraph(src_func_id, src_outside_compilation_id) && + IsInSubgraph(dst_func_id, dst_outside_compilation_id) && src_func_id == dst_func_id) { Graph* g = subgraphs_[src_func_id].GetGraph(); if (edge->IsControlEdge()) { @@ -625,23 +1138,60 @@ Status Encapsulator::CopySubgraphEdges( } // Record 'src' as an output of its subgraph, if applicable. - if (IsInSubgraph(src_func_id)) { - Subgraph& src_subgraph = subgraphs_[src_func_id]; - // Ignore control edges leaving the subgraph. We will lift them onto the - // enclosing call operators in BuildOutputGraph(). + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { if (!edge->IsControlEdge()) { - TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); + DataType dtype = edge->src()->output_type(edge->src_output()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as results: " + "tensor ", + edge->src()->name(), ":", edge->src_output()); + } + } + + Subgraph& src_subgraph = subgraphs_[src_func_id]; + if (src_func_id == dst_func_id) { + // src is in the subgraph and dst is outside_compilation in the same + // subgraph. + src_subgraph.RecordOutsideCompilationInputOrControl( + dst_outside_compilation_id, edge); + } else { + // Ignore control edges leaving the subgraph. We will lift them onto the + // enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); + } } } // Record 'dst' as an input of its subgraph, if applicable. 
- if (IsInSubgraph(dst_func_id)) { - Subgraph& dst_subgraph = subgraphs_[dst_func_id]; - // Ignore control edges entering the subgraph. We will lift them onto - // the enclosing call operators in BuildOutputGraph(). + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + // Look at the type of the destination not the source, since Ref output + // Tensors can be automatically cast to non-Ref Tensors at the + // destination. if (!edge->IsControlEdge()) { - TF_RETURN_IF_ERROR( - dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); + DataType dtype = edge->dst()->input_type(edge->dst_input()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as args: " + "tensor ", + edge->src()->name(), ":", edge->src_output()); + } + } + + Subgraph& dst_subgraph = subgraphs_[dst_func_id]; + if (src_func_id == dst_func_id) { + // dst is in the subgraph and src is outside_compilation in the same + // subgraph. + dst_subgraph.RecordOutsideCompilationOutputOrControl( + src_outside_compilation_id, edge); + } else { + // Ignore control edges entering the subgraph. We will lift them onto + // the enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR( + dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); + } } } } @@ -663,6 +1213,17 @@ Status Encapsulator::SplitIntoSubgraphs() { TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images)); TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs)); + // For each subgraph, add the nodes that deal with inputs and outputs of its + // nested outside_compilation subgraphs. These could not be added earlier + // during CopySubgraphEdges since we need to discover all the types of the + // inputs and outputs for an outside_compilation subgraph before creating a + // single input and output node for it. + for (auto& entry : subgraphs_) { + Subgraph& subgraph = entry.second; + TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images)); + TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images)); + } + MarkGuaranteedConstants(*graph_in_, src_arg_pairs); for (auto& entry : subgraphs_) { @@ -690,13 +1251,25 @@ Status Encapsulator::CopyNodesToOutputGraph( std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); + string outside_compilation_id; + TF_RETURN_IF_ERROR( + GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); - // Don't copy nodes that are going to be encapsulated, unless parallel - // checking is enabled. - if (IsInSubgraph(func_id) && !parallel_checking) continue; + // Don't copy nodes that are going to be encapsulated, unless parallel checking // is enabled.
+ if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) + continue; Node* image = graph_out->CopyNode(node); + if (!outside_compilation_id.empty()) { + if (parallel_checking) { + return errors::InvalidArgument( + "Parallel checking is not supported when outside_compilation " + "clusters are present."); + } + image->ClearAttr(group_attribute_); + image->ClearAttr(outside_compilation_attribute_); + } (*node_images)[node] = image; } (*node_images)[graph_in_->source_node()] = graph_out->source_node(); @@ -714,14 +1287,36 @@ Status Encapsulator::AddFunctionCallNodes( return Status::OK(); } +Status Encapsulator::AddOutsideCompilationHostIONodes( + const std::unordered_map& node_images, + Graph* graph_out) { + for (auto& subgraph_entry : subgraphs_) { + const string& subgraph_name = subgraph_entry.first; + Subgraph& subgraph = subgraph_entry.second; + TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( + subgraph_name, node_images, graph_out)); + } + return Status::OK(); +} + Status Encapsulator::FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& dst_func_id, + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, const Node* original_src_node, Node** src_image) { - if (IsInSubgraph(src_func_id)) { - // The edge is from a subgraph to a regular node in the output graph so - // use the subgraph's call node output. - *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + if (dst_func_id == src_func_id) { + // The edge is from a subgraph to an outside_compilation cluster in the + // same subgraph so use the appropriate _RecvAtHost node in the output + // graph. + TF_RET_CHECK(!dst_outside_compilation_id.empty()); + *src_image = subgraphs_.at(src_func_id) + .GetRecvAtHostNode(dst_outside_compilation_id); + } else { + // The edge is from a subgraph to a regular node in the output graph so + // use the subgraph's call node output. + *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + } } else { // The source of the edge is in the output graph so use the node image in // the output graph. @@ -730,14 +1325,21 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( return Status::OK(); } -int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, - const string& dst_func_id, - const Edge* edge) { - if (IsInSubgraph(src_func_id)) { +int Encapsulator::FindOutputSlotOfEdgeSrc( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const Edge* edge) { + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { const Subgraph& src_subgraph = subgraphs_.at(src_func_id); - // 'src' is in a subgraph and 'dst' is a regular node in the output - // graph. Use the corresponding call output instead. - return src_subgraph.GetResultIndexForEdge(edge); + if (src_func_id == dst_func_id) { + // 'src' is in a subgraph and 'dst' is outside_compilation in the same + // subgraph. Use the corresponding _RecvAtHost output instead. + return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge); + } else { + // 'src' is in a subgraph and 'dst' is a regular node in the output + // graph. Use the corresponding call output instead. 
+ return src_subgraph.GetResultIndexForEdge(edge); + } } else { // The source of the edge is in the output graph so use the regular edge // slot. @@ -746,13 +1348,23 @@ int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, } Status Encapsulator::FindOutputImageOfEdgeDst( - const string& src_func_id, const string& dst_func_id, + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, const Node* original_dst_node, Node** dst_image) { - if (IsInSubgraph(dst_func_id)) { - // The edge is to a subgraph from a regular node in the output graph so - // use the subgraph's call node input. - *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + if (src_func_id == dst_func_id) { + // The edge is to a subgraph from an outside_compilation cluster in the + // same subgraph so use the appropriate _SendFromHost node in the output + // graph. + TF_RET_CHECK(!src_outside_compilation_id.empty()); + *dst_image = subgraphs_.at(dst_func_id) + .GetSendFromHostNode(src_outside_compilation_id); + } else { + // The edge is to a subgraph from a regular node in the output graph so + // use the subgraph's call node input. + *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + } } else { // The destination of the edge is in the output graph so use the node image // in the output graph. @@ -761,14 +1373,21 @@ Status Encapsulator::FindOutputImageOfEdgeDst( return Status::OK(); } -int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, - const string& dst_func_id, - const Edge* edge) { - if (IsInSubgraph(dst_func_id)) { +int Encapsulator::FindOutputSlotOfEdgeDst( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const Edge* edge) { + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); - // 'dst' is in a subgraph and 'src' is a regular node in the output - // graph. Use the corresponding call input instead. - return dst_subgraph.GetArgIndexForEdge(edge); + if (dst_func_id == src_func_id) { + // 'dst' is in a subgraph and 'src' is outside_compilation in the same + // subgraph. Use the corresponding _SendFromHost input instead. + return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge); + } else { + // 'dst' is in a subgraph and 'src' is a regular node in the output + // graph. Use the corresponding call input instead. + return dst_subgraph.GetArgIndexForEdge(edge); + } } else { // The destination of the edge is in the output graph so use the regular // edge slot. 
@@ -777,17 +1396,21 @@ int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, } Status Encapsulator::CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, const string& dst_func_id, + const Edge* edge, const string& src_func_id, + const string& src_outside_compilation_id, const string& dst_func_id, + const string& dst_outside_compilation_id, const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out, std::unordered_set, NodeSlot::PairHasher>* edges_added) { Node* src_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( - src_func_id, dst_func_id, node_images, edge->src(), &src_image)); + src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, edge->src(), &src_image)); Node* dst_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst( - src_func_id, dst_func_id, node_images, edge->dst(), &dst_image)); + src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, edge->dst(), &dst_image)); // If this is a control edge then copy it and return. Lift control edges onto // the enclosing call operator. @@ -807,11 +1430,16 @@ Status Encapsulator::CopyEdgeToOutputGraph( return Status::OK(); } - int src_output = FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge); + int src_output = + FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id, + dst_func_id, dst_outside_compilation_id, edge); - int dst_input = FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge); + int dst_input = + FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, + dst_func_id, dst_outside_compilation_id, edge); - if (IsInSubgraph(dst_func_id) && parallel_checking) { + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + parallel_checking) { // If we are parallel checking, also feed the tensor as an input to the // corresponding parallel check subgraph. graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), @@ -839,13 +1467,18 @@ Status Encapsulator::AddEdgesToOutputGraph( for (const Edge* edge : graph_in_->edges()) { string src_func_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); + string src_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, + &src_outside_compilation_id)); string dst_func_id; - TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); + string dst_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, + &dst_outside_compilation_id)); // Ignore edges that are strictly contained within one subgraph, unless // we are constructing parallel check graphs. - if (IsInSubgraph(src_func_id) && IsInSubgraph(dst_func_id) && + if (IsInSubgraph(src_func_id, src_outside_compilation_id) && + IsInSubgraph(dst_func_id, dst_outside_compilation_id) && src_func_id == dst_func_id) { if (parallel_checking) { Node* src_image = node_images.at(edge->src()); @@ -862,9 +1495,15 @@ Status Encapsulator::AddEdgesToOutputGraph( // We have an edge that crosses a cluster boundary or is entirely within the // unclustered graph. 
- TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(edge, src_func_id, dst_func_id, - node_images, parallel_checking, - graph_out, &edges_added)); + TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( + edge, src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, parallel_checking, graph_out, + &edges_added)); + } + + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + subgraph.ConnectSequencerToOutputs(graph_out); } return Status::OK(); @@ -879,6 +1518,7 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); TF_RETURN_IF_ERROR( AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); TF_RETURN_IF_ERROR( AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); @@ -888,13 +1528,15 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, } // anonymous namespace Status EncapsulateSubgraphsInFunctions( - string group_attribute, const Graph& graph_in, - const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, - bool reuse_existing_functions, std::unique_ptr* graph_out, - FunctionLibraryDefinition* library) { + string group_attribute, string outside_compilation_attribute, + const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool parallel_checking, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { Status s; - Encapsulator encapsulator(std::move(group_attribute), &graph_in); + Encapsulator encapsulator(std::move(group_attribute), + std::move(outside_compilation_attribute), + &graph_in); TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( @@ -1021,8 +1663,8 @@ Status EncapsulateSubgraphsPass::Run( }; TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, **options.graph, rewrite_subgraph, - flags->tf_xla_parallel_checking, + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false, &graph_out, library)); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index b0987f76c91..34be4409a38 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -48,6 +48,16 @@ typedef std::function* graph_out, - FunctionLibraryDefinition* library); + string group_attribute, string outside_compilation_attribute, + const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool parallel_checking, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via _XlaLaunch operators. 
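Note: the new outside_compilation_attribute argument names a second node attribute that subdivides a compiled cluster into groups left running on the host. The sketch below is illustrative only and not part of this patch; it assumes the Input/Unary helpers and the fake test ops registered in the test file that follows, and uses the "_encapsulate"/"_outside" attribute names the tests pass as group_attribute and outside_compilation_attribute.

#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"

namespace tensorflow {

// Builds a tiny graph with one compiled node (C) and one node carved out to
// run on the host (E), then runs the encapsulation pass over it.
Status EncapsulateExample(std::unique_ptr<Graph>* graph_out,
                          FunctionLibraryDefinition* library) {
  GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
  Node* a = Input(b.opts().WithName("A"));
  // Member of cluster F1; compiled into the generated function.
  Node* c = Unary(a, b.opts().WithName("C").WithAttr("_encapsulate", "F1"));
  // Also in cluster F1, but marked as outside_compilation group O1, so it is
  // lifted out of the function and wired up through _RecvAtHost /
  // _SendFromHost nodes in the output graph.
  Unary(c, b.opts()
               .WithName("E")
               .WithAttr("_encapsulate", "F1")
               .WithAttr("_outside", "O1"));
  GraphDef graphdef;
  TF_RETURN_IF_ERROR(b.ToGraphDef(&graphdef));

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, &graph));
  return EncapsulateSubgraphsInFunctions(
      "_encapsulate", "_outside", graph, /*rewrite_subgraph_fn=*/{},
      /*parallel_checking=*/false, /*reuse_existing_functions=*/false,
      graph_out, library);
}

}  // namespace tensorflow

After the pass runs, cluster F1 becomes a library function called from graph_out, while E remains in graph_out connected through the host I/O nodes added by AddOutsideCompilationHostIONodes; the tests below check exactly this output structure.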
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 717efb36018..b100861d5e9 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -36,7 +36,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, if (diff) { *diff = strings::StrCat("Definition mismatch for function ", a.signature().name(), ", expected:\n", - a.DebugString()); + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -82,6 +82,24 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, << diff << "\nActual: " << actual.DebugString(); \ } while (false) +// TODO(misard): remove these fake registrations once there are real Ops to be +// compiled. +REGISTER_OP("_XlaSendToHost") + .Input("input: dtypes") + .Attr("dtypes: list(type) >= 0"); + +REGISTER_OP("_XlaRecvFromHost") + .Output("output: dtypes") + .Attr("dtypes: list(type) >= 0"); + +REGISTER_OP("_XlaSendFromHost") + .Input("input: dtypes") + .Attr("dtypes: list(type) >= 0"); + +REGISTER_OP("_XlaRecvAtHost") + .Output("output: dtypes") + .Attr("dtypes: list(type) >= 0"); + REGISTER_OP("InputTest").Output("o: float"); REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); @@ -98,10 +116,32 @@ REGISTER_OP("AddNLikeTest") .SetIsCommutative() .SetIsAggregate(); +Node* NoOp(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("NoOp", opts); +} + Node* Input(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTest", opts); } +Node* RecvAtHost(const gtl::ArraySlice& dtypes, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), + "_XlaRecvAtHost", opts.op_registry()); + return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); +} + +Node* SendFromHost(const std::vector& inputs, + const gtl::ArraySlice& dtypes, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), + "_XlaSendFromHost", opts.op_registry()); + node_builder.Input(inputs); + return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); +} + Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { return ops::UnaryOp("UnaryTest", std::move(a), opts); } @@ -145,7 +185,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { if (!s.ok()) return s; std::unique_ptr graph_out; - s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, + s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, /*reuse_existing_functions=*/false, @@ -178,6 +218,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) { FunctionDefLibrary library_out = library_in; TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out)); + // If there are no marked nodes, funcification should be a no-op. TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out); } @@ -230,7 +271,6 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - // If there are no marked nodes, funcification should be a no-op. 
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -342,9 +382,9 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph, - &library)); + "_cluster", "_outside", graph_before_encapsulation, + /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); @@ -374,9 +414,9 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph, - &library)); + "_cluster", "_outside", graph_before_encapsulation, + /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, + /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = { "add1", "add2", "cluster1", "cluster1_parallel_check/_0", @@ -432,7 +472,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); int guaranteed_consts = 0; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_encapsulate", graph_before, + "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ [&guaranteed_consts](std::unique_ptr* graph_ptr, std::vector* input_permutation, @@ -477,7 +517,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); int guaranteed_consts = 0; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_encapsulate", graph_before, + "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ [&guaranteed_consts](std::unique_ptr* graph_ptr, std::vector* input_permutation, @@ -502,5 +542,678 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { EXPECT_EQ(1, guaranteed_consts); } +// Test with one function to transform and one outside_compilation cluster. +TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. 
+ Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {}, + {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"C:o:0", "c:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"c"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().FinalizeBuilder(&node_builder); + + Node* recv = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + b2.opts().WithName("E").WithControlInputs({recv, b})); + Node* send = SendFromHost({e}, {DT_FLOAT}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* s = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one function to transform and two outside_compilation clusters. 
+TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = Binary(e, f, + b1.opts() + .WithName("G") + .WithControlInputs({e, f}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* h = Binary(d, e, + b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1")); + Binary(g, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, + {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {}, + {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O2_send"}, + "_XlaSendToHost", + {"D:o:0", "F:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"F"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"C:o:0", "D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"D"}}, + {{"outside_compilation_O2_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O2_send"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().FinalizeBuilder(&node_builder); + + Node* recv1 = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts().WithName("E").WithControlInputs({recv1, b})); + Node* send1 = SendFromHost({e}, {DT_FLOAT}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* recv2 = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O2_recv")); + Node* g = Binary(e, ops::NodeOut(recv2, 1), + b2.opts().WithName("G").WithControlInputs({recv2, e})); + Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); + Node* send2 = SendFromHost( + {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send")); + + Node* s = NoOp(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2})); + + Binary(g, 
call, b2.opts().WithName("J").WithControlInput(s)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two functions to transform, each with one outside_compilation +// cluster. +TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = Binary(e, f, + b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr( + "_encapsulate", "F2")); + Node* h = Binary(d, g, + b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + Node* i = + Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); + Binary(g, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"f_0_retval:float", "d_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {}, + {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"C:o:0", "D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"D"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F2", {"e_0_arg:float", "f_0_arg:float"}, + {"g_0_retval:float", "i_0_retval:float"}, {}, + { + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, + {{"I"}, + "BinaryTest", + {"f_0_arg", "outside_compilation_O1_recv:output:0"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"G:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts().WithName("E").WithControlInputs({recv1, b})); + Node* send1 = SendFromHost({e}, {DT_FLOAT}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + 
node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Node* recv2 = RecvAtHost( + {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* h = Binary(ops::NodeOut(call1, 1), recv2, + b2.opts().WithName("H").WithControlInput(s1)); + Node* send2 = SendFromHost( + {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send")); + + NodeBuilder node_builder2("F2", "F2", lib_def.get()); + node_builder2.Input(e).Input(call1); + Node* call2 = b2.opts() + .WithControlInputs({s1, e, call1}) + .FinalizeBuilder(&node_builder2); + Node* s2 = NoOp( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2})); + Binary(call2, ops::NodeOut(call2, 1), + b2.opts().WithName("J").WithControlInput(s2)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no inputs from the +// compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = + Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Unary(f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"D:o:0", "outside_compilation_O1_recv:output:0"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts().WithName("E")); + Node* send1 = SendFromHost( + {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1)); + + Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no data inputs but has a +// control input from the compiled subgraph. 
+TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithControlInput(d) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = + Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Unary(f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"D:o:0", "outside_compilation_O1_recv:output:0"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {}, + {{"dtypes", gtl::ArraySlice({})}}, + {"D"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); + Node* send1 = SendFromHost( + {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no outputs from the +// compiled subgraph. 
+TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = RecvAtHost( + {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(recv1, b2.opts().WithName("E")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1)); + + Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no data outputs but has a +// control output to the compiled subgraph. 
+TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({})}}, + {"outside_compilation_O1_send"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = RecvAtHost( + {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(recv1, b2.opts().WithName("E")); + Node* send1 = SendFromHost({}, {}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no inputs from or +// outputs to the compiled subgraph.
+TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts().WithName("E")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("G")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 1f311a3aedb..79b02baba83 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -41,6 +41,7 @@ limitations under the License. namespace tensorflow { const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; namespace { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index f91695800f5..e9acbfb19e4 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -28,6 +28,10 @@ namespace tensorflow { // encapsulate subgraphs pass. extern const char* const kXlaClusterAttr; +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + // Pass that marks a subset of operators in the graph with attribute // _XlaCluster so they are compiled by the EncapsulateSubgraphsPass. 
class MarkForCompilationPass : public GraphOptimizationPass { diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index f9803be32f5..6435226fbe6 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -270,8 +270,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } -/* static */ tensorflow::gtl::ArraySlice -LayoutUtil::PaddedDimensions(const Shape& shape) { +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( + const Shape& shape) { CHECK(IsDense(shape)); return AsInt64Slice(shape.layout().padded_dimensions()); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index d00cd037563..f73cc957649 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -96,7 +96,7 @@ class LayoutUtil { // Returns the padded_dimensions array for the given Shape. Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice PaddedDimensions( + static tensorflow::gtl::ArraySlice PaddedDimensions( const Shape& shape); // Returns the given index of the padded_dimensions array for the given Shape. diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index c782e0f19e5..6254fafaf3e 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -341,7 +341,7 @@ class Literal { // Creates a literal of the given shape where each element is `value`. template - static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( + static std::unique_ptr CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value); // Creates a new literal from an array. 
The variants not ending with @@ -1233,10 +1233,9 @@ void Literal::PopulateWithValue(NativeT value, } template -/* static */ std::unique_ptr -Literal::CreateFullWithMonotonicDim0MajorLayout( +/* static */ std::unique_ptr Literal::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape this_shape = ShapeUtil::MakeShapeWithDescendingLayout( primitive_util::NativeToPrimitiveType(), dimensions); auto literal = MakeUnique(); *literal->mutable_shape() = this_shape; diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index a6b8158671f..ece1f875f9f 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -46,6 +46,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 0b0a53fac7a..e2c45280f3c 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -166,6 +166,18 @@ ComputationDataHandle LocalComputationBuilder::Dot( return builder_.Dot(lhs, rhs); } +ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers); +} + ComputationDataHandle LocalComputationBuilder::ConvertElementType( const ComputationDataHandle& operand, PrimitiveType new_element_type) { return builder_.ConvertElementType(operand, new_element_type); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index cbab45a5f01..b3cbcc3f296 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -113,6 +114,14 @@ class LocalComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + ComputationDataHandle ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice > padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, PrimitiveType new_element_type); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index ac8f3e42777..980bad09c1b 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,18 +22,19 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> long -// ArraySlice <- sequence of long -// ArraySlice <- sequence of long +// ComputationDataHandle <-> int +// ArraySlice <- sequence of int +// ArraySlice <- sequence of int // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape <-> pair holding (dtype, dimensions) // std::vector <- sequence of shape information pairs // PrimitiveType <- int +// ArraySlice> <- sequence of int pairs +// ConvolutionDimensionNumbers proto <- corresponding Python proto // // Arrows indicate whether a conversion only ever occurs in one -// direction, or whether it is maintained bidirectionally. Also, -// "long" and "int" denote the Python types so named, not C. +// direction, or whether it is maintained bidirectionally. // // The Python objects corresponding to C++ Literals have the type: // @@ -113,6 +114,27 @@ limitations under the License. using namespace xla; using namespace xla::swig; + +namespace xla { +namespace swig { + +bool GetIntAttr(PyObject* o, const char* field, int64* result) { + PyObject* fo = PyObject_GetAttrString(o, field); + if (!fo) { + return false; + } + const int64 value = numpy::PyIntOrPyLongToLong(fo); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(fo); + return false; + } + Py_DECREF(fo); + *result = value; + return true; +} + +} +} %} // Required to use PyArray_* functions. 
@@ -278,6 +300,189 @@ tensorflow::ImportNumpy(); $1 = static_cast(value); } +// ArraySlice> + +%typemap(in) tensorflow::gtl::ArraySlice > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!o) { + return NULL; + } + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + Py_DECREF(o); + return NULL; + } + PyObject* first_pyint = numpy::PyNumberToPyInt(first); + if (!first_pyint) { + PyErr_SetString( + PyExc_TypeError, + "First pair item cannot be converted to int"); + Py_DECREF(o); + return NULL; + } + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + Py_DECREF(o); + Py_DECREF(first_pyint); + return NULL; + } + PyObject* second_pyint = numpy::PyNumberToPyInt(second); + if (!second_pyint) { + PyErr_SetString( + PyExc_TypeError, + "Second pair item cannot be converted to int"); + Py_DECREF(o); + Py_DECREF(first_pyint); + return NULL; + } + const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); + if (first_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + return NULL; + } + const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); + if (second_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + return NULL; + } + temps.push_back(std::make_pair(first_value, second_value)); + Py_DECREF(o); + } + $1 = temps; +} + +// ConvolutionDimensionNumbers + +%typemap(in) const ConvolutionDimensionNumbers& + (ConvolutionDimensionNumbers dimension_numbers) { + int64 value; + + if (!GetIntAttr($input, "input_batch_dimension", &value)) { + return NULL; + } + dimension_numbers.set_input_batch_dimension(value); + + if (!GetIntAttr($input, "input_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_input_feature_dimension(value); + + if (!GetIntAttr($input, "output_batch_dimension", &value)) { + return NULL; + } + dimension_numbers.set_output_batch_dimension(value); + + if (!GetIntAttr($input, "output_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_kernel_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_kernel_input_feature_dimension(value); + + PyObject* o; + int length; + + o = PyObject_GetAttrString($input, "input_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_input_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + 
Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_kernel_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + o = PyObject_GetAttrString($input, "output_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_output_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + $1 = &dimension_numbers; +} + %ignoreall %unignore xla; %unignore xla::swig; @@ -314,6 +519,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Lt; %unignore xla::swig::LocalComputationBuilder::Le; %unignore xla::swig::LocalComputationBuilder::Dot; +%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; %unignore xla::swig::LocalComputationBuilder::Add; %unignore xla::swig::LocalComputationBuilder::Sub; %unignore xla::swig::LocalComputationBuilder::Mul; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index c75d54856dd..36fbd0a1bcf 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import enum # pylint: disable=g-bad-import-order import itertools import numpy as np @@ -25,6 +26,12 @@ import numpy as np from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + + _UNARY_OPS = [ 'Not', 'Abs', @@ -564,6 +571,79 @@ class ComputationBuilder(object): return _wrap_data_handle( self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + def Conv(self, lhs, rhs, window_strides, padding): + """Enqueues a Conv operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank N+2 array of inputs. + rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: PaddingType representing either 'SAME' or 'VALID' padding. + + Returns: a ComputationDataHandle representing the Conv operation. 
+ """ + if padding == PaddingType.SAME: + lhs_dims = self.GetShape(lhs).dimensions() + rhs_dims = self.GetShape(rhs).dimensions() + in_shape, filter_shape = lhs_dims[2:], rhs_dims[2:] + out_shape = np.ceil(np.true_divide(in_shape, window_strides)).astype(int) + pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size + in zip(out_shape, window_strides, filter_shape, in_shape)] + pads = [(pad_size // 2, pad_size - pad_size // 2) + for pad_size in pad_sizes] + else: + pads = [(0, 0)] * len(window_strides) + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + return _wrap_data_handle( + self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), + _unwrap_data_handle(rhs), + window_strides, + pads, + (), + (), + dimension_numbers)) + + def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation): + """Enqueues a ConvWithGeneralPadding operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank N+2 array of inputs. + rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + window_strides: length-N array-like of kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of dilation factors. + rhs_dilation: length-N array-like of dilation factors. + + Returns: + A ComputationdataHandle representing the added ConvWithGeneralPadding op. + """ + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + return _wrap_data_handle( + self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), + _unwrap_data_handle(rhs), + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers)) + + def _GetConvDimensionNumbers(self, num_spatial_dims): + """Create ConvolutionDimensionNumbers proto for convolutions.""" + nd = num_spatial_dims + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + return dimension_numbers + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. 
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 878cd83edcc..bcbc5604288 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -386,6 +386,46 @@ class SingleOpTest(LocalComputationTest): c.Dot(c.Constant(lhs), c.Constant(rhs)) self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + def testConvF32Same(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 3, 4) + rhs = a(1, 2, 1, 2) * 10 + c.Conv(c.Constant(lhs), c.Constant(rhs), + [1, 1], xla_client.PaddingType.SAME) + result = np.array([[[[640., 700., 760., 300.], + [880., 940., 1000., 380.], + [1120., 1180., 1240., 460.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvF32Valid(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 3, 4) + rhs = a(1, 2, 1, 2) * 10 + c.Conv(c.Constant(lhs), c.Constant(rhs), + [2, 1], xla_client.PaddingType.VALID) + result = np.array([[[[640., 700., 760.], + [1120., 1180., 1240.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvWithGeneralPaddingF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 7dc09a8cbd2..6d7e2576eb3 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1182,6 +1182,11 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { } Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { + if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { + return ReplaceWithNewInstruction( + pad, HloInstruction::CreateBroadcast(pad->shape(), + pad->mutable_operand(1), {})); + } // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice. 
bool all_zero = true; @@ -1624,6 +1629,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* reduce_window) { + if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateBroadcast(reduce_window->shape(), + reduce_window->mutable_operand(1), {})); + } auto operand = reduce_window->mutable_operand(0); const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); @@ -1694,7 +1705,6 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); - if (std::is_sorted(transpose->dimensions().begin(), transpose->dimensions().end())) { VLOG(10) << "deleting no-op transpose"; @@ -1721,6 +1731,18 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { auto lhs = convolution->mutable_operand(0); auto rhs = convolution->mutable_operand(1); + if (ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(convolution->shape().element_type(), {}), + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), + {})); + } const auto& window = convolution->window(); if (!enable_conv_simplification_) { return Status::OK(); @@ -1813,18 +1835,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // We already checked feature_dimension is most minor, so data in input_shape // and row-major {conv_width,input_channels} are bitwise identical. - const Shape new_input_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {conv_width, input_channels}); + const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {conv_width, input_channels}); // We already checked input_feature_dimension is more major than // output_feature_dimension, so data in filter_shape and row-major // {input_channels,output_channels} are bitwise identical. - const Shape new_filter_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - filter_shape.element_type(), {input_channels, output_channels}); - const Shape dot_output_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - convolution_shape.element_type(), {conv_width, output_channels}); + const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( + filter_shape.element_type(), {input_channels, output_channels}); + const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + convolution_shape.element_type(), {conv_width, output_channels}); // We cannot insert bitcasts if the layouts will not be compatible. 
// TODO(b/33178038): Consider inserting a transpose if a bitcast would be diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index d4739ca113a..175d4d8d7fb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -816,6 +816,120 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { 1); } +TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); + + HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs")); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.set_input_feature_dimension(2); + + dnums.set_output_batch_dimension(0); + dnums.add_output_spatial_dimensions(1); + dnums.set_output_feature_dimension(2); + + dnums.add_kernel_spatial_dimensions(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.set_kernel_output_feature_dimension(2); + Window window; + WindowDimension* dim = window.add_dimensions(); + dim->set_size(3); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_window_reversal(false); + // Create add computation. + std::unique_ptr module = CreateNewModule(); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + module->AddEntryComputation(builder.Build()); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Convolution(lhs, rhs)); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); + Window window; + for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(1); + dim->set_padding_high(1); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + // Create add computation. 
+ std::unique_ptr module = CreateNewModule(); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module->AddEmbeddedComputation(builder.Build()); + } + builder.AddInstruction(HloInstruction::CreateReduceWindow( + ShapeUtil::MakeShape(F32, {5, 2}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + window, add_computation)); + module->AddEntryComputation(builder.Build()); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::ReduceWindow(param, op::Constant())); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); + PaddingConfig padding; + for (int i = 0; i < 2; ++i) { + PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions(); + dimension->set_edge_padding_low(1); + dimension->set_edge_padding_high(1); + dimension->set_interior_padding(0); + } + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {5, 2}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + padding)); + std::unique_ptr module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Pad(param, op::Constant())); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1309,7 +1423,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(F32, {2, 2, 2}), + 0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}), "param0")); HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index b69a6e730fc..4e80679c11d 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -28,8 +28,6 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -namespace se = ::perftools::gputools; - namespace xla { StatusOr AllocationTracker::Register( diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 7c25a2e6450..46c94d0bf1e 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -96,6 +96,26 @@ class ConvolutionThunk : public Thunk { return !best_algorithm_.has_value(); } + // Return true if scratch memory is needed to execute the thunk, that is + // either the best algorithm hasn't been chosen or the best algorithm is not + // the same as the no-scratch algorithm. This is because that the execution + // of the thunk is asynchronous, and the scratch allocator goes out of + // scope before the thunk finishes execution. Returning true tells the stream + // executor to make future thunks wait for this thunk to avoid reusing the + // deallocated scratch memory until this thunk is done with it. + bool ShouldBlockFutureThunks() { + if (!best_algorithm_.has_value()) { + return true; + } + + const perftools::gputools::dnn::AlgorithmDesc& best_alg = + best_algorithm_->algorithm(); + const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg = + best_algorithm_->algorithm_no_scratch(); + return (!best_alg.is_default() || !no_scratch_best_alg.is_default() || + !(best_alg == no_scratch_best_alg)); + } + private: tensorflow::Status ConvolveWithTune( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 366d87e9c30..df6c6668c33 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -155,6 +155,7 @@ Status GpuExecutable::ExecuteThunks( run_options->BorrowStream(main_stream->parent()->device_ordinal())); } + std::map last_blocking_thunk_for_stream; std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { TF_RETURN_IF_ERROR(thunk->Initialize(*this)); @@ -167,10 +168,17 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } + if (last_blocking_thunk_for_stream.count(stream_no)) { + stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, + last_blocking_thunk_for_stream[stream_no]) + .get()); + } + // If this thunk requests it, wait for all currently-executing thunks to // finish. This is useful e.g. if the thunk is about to perform autotuning. 
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); + last_blocking_thunk_for_stream.clear(); } profiler.StartOperation(); @@ -178,11 +186,14 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk)) { + if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); + if (thunk->ShouldBlockFutureThunks()) { + last_blocking_thunk_for_stream[stream_no] = thunk; + } } profiler.FinishOperation(thunk->hlo_instruction()); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 022c63de8db..75021280cb7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -419,8 +419,8 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, (segs.size() == i ? shape.dimensions().size() : segs[i]), 1, std::multiplies())); } - return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), - dimensions); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + dimensions); } // Returns whether the given shapes and permutation are a 0-2-1 transpose, and @@ -442,11 +442,13 @@ std::tuple IsTranspose021(const Shape& a, const Shape& b) { } } auto segs = ConsecutiveSegments(perm); - Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a); - Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b); + Shape norm_a = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); + Shape norm_b = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b); if (3 == segs.size() && 0 == perm[0]) { Shape reduced_a = MergeDimensions(segs, norm_a); - Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout( b.element_type(), Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions()))); return std::make_tuple(true, reduced_a, reduced_b); @@ -460,10 +462,11 @@ std::tuple IsTranspose021(const Shape& a, const Shape& b) { bool AreShapesForTranspose021(const Shape& a, const Shape& b) { return 3 == b.dimensions().size() && ShapeUtil::Compatible( - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a), + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a), ShapeUtil::PermuteDimensions( {0, 2, 1}, - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b))); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + b))); } // Emits a tiled 0-2-1 transpose, assuming both input and output lain out from @@ -495,9 +498,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape())); Shape input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input.GetShape()); Shape output_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + output.GetShape()); input = input.CastToShape(input_shape, builder); output = 
output.CastToShape(output_shape, builder); @@ -615,7 +620,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder))), builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"), - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + ShapeUtil::MakeShapeWithDescendingLayout( PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)), builder); const llvm_ir::IrArray::Index input_tile_origin = ({ @@ -811,14 +816,15 @@ Status IrEmitterUnnested::EmitColumnReduction( // input_shape to normalized_input_shape and a reshape from // normalized_input_shape to input_matrix_shape. const Shape normalized_input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_matrix_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {height, width}); + ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), + {height, width}); const llvm_ir::IrArray::Index input_matrix_index( {y, x}, input_matrix_shape, &ir_builder_); const llvm_ir::IrArray::Index input_index = @@ -1054,13 +1060,14 @@ Status IrEmitterUnnested::EmitRowReduction( // from input_shape to normalized_input_shape and a reshape from // normalized_input_shape to input_3d_tensor_shape. const Shape normalized_input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {depth, height, width}); + ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), + {depth, height, width}); const llvm_ir::IrArray::Index input_3d_tensor_index( {z, y, x}, input_3d_tensor_shape, &ir_builder_); const llvm_ir::IrArray::Index input_index = diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 486ea7d7e1d..1b1d03d46ce 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -83,6 +83,16 @@ class Thunk { return false; } + // Indicates whether thunks scheduled after this one should wait for this one + // to complete before running. For example, a convolution thunk creates a + // scratch allocator, then kicks off a convolution in cudnn via the stream + // executor. When the stream executor call returns, the scratch allocator goes + // out of scope, and the scratch memory is deallocated. In this case, the + // convolution thunk needs to return true so that future thunks wait for the + // convolution thunk to avoid reusing the deallocated memory until the + // convolution thunk is done with it. + virtual bool ShouldBlockFutureThunks() { return false; } + // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. 
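To make the interaction between ShouldBlockFutureThunks() and the new last_blocking_thunk_for_stream bookkeeping in GpuExecutable::ExecuteThunks easier to follow, here is a minimal Python sketch of the scheduling rule. All names (Event, plan_waits, ...) are stand-ins for illustration; this is not the stream-executor API, and the cross-thunk dependency and halt-all-activity cases are omitted.

```python
class Event(object):
  """Stand-in for a stream-executor finish event."""

def plan_waits(thunks, stream_of, is_depended_on, blocks_future):
  """Returns {thunk: [events it must wait for]} under the rule added above."""
  waits = {}
  finish_event = {}
  last_blocking = {}  # stream -> last thunk on that stream that blocks others
  for thunk in thunks:  # assumed to be in thunk-schedule total order
    stream = stream_of[thunk]
    waits[thunk] = []
    if stream in last_blocking:
      # Mirrors the new last_blocking_thunk_for_stream wait in ExecuteThunks.
      waits[thunk].append(finish_event[last_blocking[stream]])
    if is_depended_on(thunk) or blocks_future(thunk):
      finish_event[thunk] = Event()
      if blocks_future(thunk):
        last_blocking[stream] = thunk
  return waits

# A convolution thunk "conv" blocks future thunks, so the later "add" thunk on
# the same stream waits for conv's finish event and cannot run while the
# convolution might still be using its (already deallocated) scratch memory.
waits = plan_waits(["conv", "add"], {"conv": 0, "add": 0},
                   lambda t: False, lambda t: t == "conv")
print(len(waits["conv"]), len(waits["add"]))  # 0 1
```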
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 97697d06b73..97765d65909 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1361,7 +1361,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); std::unique_ptr arg_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); + Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); @@ -1414,7 +1414,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 8.0f); + Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); LiteralTestUtil::ExpectEqual(*result_literal, *result); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 42bca3b783c..e40f7cae33b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -478,7 +478,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } else if (instruction->opcode() == HloOpcode::kCustomCall) { // Add constraints for kCustomCall instruction operands and instructions. // For now we only support major-first layouts for all inputs and outputs. - Shape result_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( instruction->shape().element_type(), AsInt64Slice(instruction->shape().dimensions())); TF_RETURN_IF_ERROR( @@ -491,7 +491,7 @@ Status LayoutAssignment::AddMandatoryConstraints( } Shape row_major_operand_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + ShapeUtil::MakeShapeWithDescendingLayout( operand_shape.element_type(), AsInt64Slice(operand_shape.dimensions())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ead9f5c4ce7..48b7515ecff 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -189,20 +189,21 @@ StatusOr MakeShapeWithLayoutInternal( .ValueOrDie(); } -/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( +/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeShapeWithLayout(element_type, dimensions, layout); } -/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout( +/* static */ Shape +ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape) { std::vector dims(shape.dimensions_size()); for (int i = 0; i < shape.dimensions_size(); ++i) { dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims); + return MakeShapeWithDescendingLayout(shape.element_type(), dims); } /* static */ void ShapeUtil::PopulateShape( @@ -1148,9 +1149,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // as input_shape/output_shape and the dimension-0-major layout. 
These two // shapes are used for conversion between logical linear indices and // multi-dimensional indices. - Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + Shape input_shape_dim0_major = MakeShapeWithDescendingLayout( input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); - Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 301247d61c5..a2043eff1ed 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -268,14 +268,18 @@ class ShapeUtil { PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major); - // Constructs a new shape with major-first layout. - static Shape MakeShapeWithMonotonicDim0MajorLayout( + // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). + static Shape MakeShapeWithDescendingLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions); - // Returns a new shape with major-first layout that has the same layout of - // elements with a different shape. - static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape); + // Returns a new Shape based on the given Shape with low-dimension-major + // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions + // rearranged so that it has the same in-memory layout as the given shape. + // + // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}. + static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + const Shape& shape); // As MakeShape, but the object to write to is passed in. 
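The renamed layout helpers can be hard to picture from the prose alone. A minimal NumPy-style sketch of what MakeShapeWithDescendingLayoutAndSamePhysicalLayout computes, using the f32[B,H,W,C]{0,3,2,1} example from the header comment (variable names are illustrative only):

```python
# f32[B,H,W,C]{0,3,2,1}: minor_to_major says B is most minor, then C, W, H.
dims = {"B": 2, "H": 4, "W": 5, "C": 3}
logical_dims = ["B", "H", "W", "C"]
minor_to_major = [0, 3, 2, 1]

# Walk the dimensions from most major to most minor and relabel them 0..n-1,
# i.e. MakeShapeWithDescendingLayout applied to the physically-ordered dims.
major_to_minor = list(reversed(minor_to_major))            # [1, 2, 3, 0]
new_logical_dims = [logical_dims[i] for i in major_to_minor]
print(new_logical_dims)                                     # ['H', 'W', 'C', 'B']
print([dims[d] for d in new_logical_dims])                  # [4, 5, 3, 2]
# i.e. f32[H,W,C,B]{3,2,1,0}, with the same bytes-in-memory order as before.
```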
static void PopulateShape(PrimitiveType element_type, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index bf81514bc90..e857d5b50db 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -393,7 +393,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { auto shape = ShapeUtil::MakeShape(F32, input_dims); std::unique_ptr arg_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); + Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); @@ -402,7 +402,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = - Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); + Literal::CreateFullWithDescendingLayout(output_dims, 9.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index d5ad1453278..5ae6188dfda 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -397,6 +397,7 @@ py_test( srcs = ["scan_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -491,6 +492,25 @@ py_test( ], ) +py_test( + name = "unique_dataset_op_test", + size = "small", + srcs = ["unique_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//third_party/py/numpy", + ], +) + py_test( name = "zip_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 506eefbef02..54f6974dba2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -735,6 +735,20 @@ class BatchDatasetSerializationTest( lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), num_outputs) + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + def testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 5338ec56bf2..e0494736b72 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,6 +21,7 @@ import itertools import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -124,5 +125,18 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) +class ScanDatasetSerialzationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py new file mode 100644 index 00000000000..55296d5710e --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class UniqueDatasetTest(test.TestCase): + + def _testSimpleHelper(self, dtype, test_cases): + """Test the `unique()` transformation on a list of test cases. + + Args: + dtype: The `dtype` of the elements in each test case. + test_cases: A list of pairs of lists. The first component is the test + input that will be passed to the transformation; the second component + is the expected sequence of outputs from the transformation. + """ + + # The `current_test_case` will be updated when we loop over `test_cases` + # below; declare it here so that the generator can capture it once. 
+ current_test_case = [] + dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case, + dtype).apply(unique.unique()) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for test_case, expected in test_cases: + current_test_case = test_case + sess.run(iterator.initializer) + for element in expected: + if dtype == dtypes.string: + element = compat.as_bytes(element) + self.assertAllEqual(element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSimpleInt(self): + for dtype in [dtypes.int32, dtypes.int64]: + self._testSimpleHelper(dtype, [ + ([], []), + ([1], [1]), + ([1, 1, 1, 1, 1, 1, 1], [1]), + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]), + ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]), + ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]), + ]) + + def testSimpleString(self): + self._testSimpleHelper(dtypes.string, [ + ([], []), + (["hello"], ["hello"]), + (["hello", "hello", "hello"], ["hello"]), + (["hello", "world"], ["hello", "world"]), + (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]), + ]) + + +class UniqueSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 1f35ee056b7..00af1f0b8ed 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -105,6 +105,7 @@ py_library( "resampling.py", "scan_ops.py", "stats_ops.py", + "unique.py", ], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py new file mode 100644 index 00000000000..133e17d20d0 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unique element dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import gen_dataset_ops + + +def unique(): + """Creates a `Dataset` from another `Dataset`, discarding duplicates. 
+ + Use this transformation to produce a dataset that contains one instance of + each unique element in the input. For example: + + ```python + dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1]) + + # Using `unique()` will drop the duplicate elements. + dataset = dataset.apply(tf.contrib.data.unique()) # ==> { 1, 37, 2 } + ``` + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return UniqueDataset(dataset) + + return _apply_fn + + +class UniqueDataset(dataset_ops.Dataset): + """A `Dataset` contains the unique elements from its input.""" + + def __init__(self, input_dataset): + """See `unique()` for details.""" + super(UniqueDataset, self).__init__() + self._input_dataset = input_dataset + if input_dataset.output_types not in (dtypes.int32, dtypes.int64, + dtypes.string): + raise TypeError( + "`tf.contrib.data.unique()` only supports inputs with a single " + "`tf.int32`, `tf.int64`, or `tf.string` component.") + + def _as_variant_tensor(self): + return gen_dataset_ops.unique_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 95fba59e3c9..4928bf2c10e 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -110,6 +110,7 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/contrib/tpu", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index d255a6e7160..c8631ed89ba 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -22,11 +22,14 @@ import numpy as np import numpy.random as npr from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -95,6 +98,18 @@ class SubGraphTest(test.TestCase): filtered_list = sub_graph.filter_list(input_list) self.assertEqual(filtered_list, [b]) + def testVariableUses(self): + with ops.Graph().as_default(): + var = variable_scope.get_variable('var', shape=[10, 10]) + resource_var = variable_scope.get_variable( + 'resource_var', shape=[10, 10], use_resource=True) + x = array_ops.zeros([3, 10]) + z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var) + z1 = math_ops.matmul(x, resource_var) + sub_graph = utils.SubGraph((z0, z1)) + self.assertEqual(2, 
sub_graph.variable_uses(var)) + self.assertEqual(1, sub_graph.variable_uses(resource_var)) + class UtilsTest(test.TestCase): @@ -253,6 +268,25 @@ class UtilsTest(test.TestCase): np_inv = np.linalg.inv(x + damp * np.eye(size)) self.assertAllClose(sess.run(tf_inv), np_inv) + def testCrossReplicaMean(self): + """Ensures that cross_replica_mean() executes only when num_shards > 1.""" + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(4): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertNotEqual(mean, tensor) + + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(1): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertEqual(mean, tensor) + + with ops.Graph().as_default(): + with self.assertRaises(ValueError): # Outside of TPU context. + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index 3d731c7bc20..9be3d60dc06 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -196,6 +196,7 @@ py_library( srcs = ["utils.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/tpu", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 5a6d1a93ff2..e4e81cd13de 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -267,6 +267,10 @@ class FisherFactor(object): new_cov = math_ops.add_n( tuple(self._compute_new_cov(idx) for idx in range(self._num_sources))) + # Synchronize value across all TPU cores. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + return moving_averages.assign_moving_average( self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index cec018e406b..48b191ef501 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -27,6 +29,8 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables # Method used for inverting matrices. POSDEF_INV_METHOD = "cholesky" @@ -226,11 +230,13 @@ class SubGraph(object): """ def __init__(self, outputs): + # Set of all ancestor Tensors, Ops to 'outputs'. 
self._members = set() self._recurse_add(outputs) def _recurse_add(self, nodes): + """Recursively adds all of nodes' ancestors.""" for node in nodes: if node in self._members: continue @@ -246,8 +252,25 @@ class SubGraph(object): return node in self._members def variable_uses(self, var): - """Computes number of times a variable is used.""" - return len(self._members.intersection(set(var.value().consumers()))) + """Computes number of times a variable is used. + + Args: + var: Variable or ResourceVariable instance. + + Returns: + Number of times a variable is used within this subgraph. + + Raises: + ValueError: If 'var' is not a variable type. + """ + if isinstance(var, resource_variable_ops.ResourceVariable): + var = var.handle + elif isinstance(var, variables.Variable): + var = var.value() + else: + raise ValueError("%s does not appear to be a variable." % str(var)) + + return len(self._members.intersection(set(var.consumers()))) def filter_list(self, node_list): """Filters 'node_list' to nodes in this subgraph.""" @@ -292,5 +315,34 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): return dysdx + +def on_tpu(): + """Returns True when building a TPU computation.""" + return tpu_function.get_tpu_context().number_of_shards is not None + + +def cross_replica_mean(tensor, name=None): + """Takes mean value of a Tensor across all TPU cores. + + Args: + tensor: Tensor to be synchronized. + name: None or string. Name of Op. + + Returns: + Average of Tensor across all TPU cores. + + Raises: + ValueError: If called outside of TPU context. + """ + with ops.name_scope(name, "cross_replica_mean", [tensor]): + num_shards = tpu_function.get_tpu_context().number_of_shards + if num_shards is None: + raise ValueError( + "Cannot take cross_replica_mean() outside of TPU Context.") + if num_shards == 1: + return tensor + return tpu_ops.cross_replica_sum(tensor / num_shards) + + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. 
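The division-before-summation in cross_replica_mean() above is just the usual mean identity; a tiny NumPy check of that identity (illustrative only, no TPU involved):

```python
import numpy as np

# mean(t_0, ..., t_{n-1}) == sum_i (t_i / n): dividing each per-core tensor by
# the shard count and then doing a cross-replica sum yields the cross-replica
# mean, which is what cross_replica_mean() relies on.
replicas = [np.array([1., 2.]), np.array([3., 4.]),
            np.array([5., 6.]), np.array([7., 8.])]
num_shards = len(replicas)
via_scaled_sum = sum(t / num_shards for t in replicas)
print(np.allclose(via_scaled_sum, np.mean(replicas, axis=0)))  # True
```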
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc index 86033275a0f..9835f86398a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -179,6 +179,10 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { ConcatenateTensorBuffers( input_arrays, concatenation_axis, &concatenated_array); break; + case ArrayDataType::kString: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + break; default: LOG(FATAL) << "ArrayDataType not supported"; } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 31eee12ffca..5d297bb11e2 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -52,6 +52,7 @@ using tensorflow::DT_BOOL; using tensorflow::DT_FLOAT; using tensorflow::DT_INT32; using tensorflow::DT_INT64; +using tensorflow::DT_STRING; using tensorflow::DT_UINT8; using tensorflow::GraphDef; using tensorflow::NodeDef; @@ -135,6 +136,8 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kInt32; else if (dtype == DT_INT64) return ArrayDataType::kInt64; + else if (dtype == DT_STRING) + return ArrayDataType::kString; else LOG(INFO) << "Unsupported data type in placehoder op: " << dtype; return ArrayDataType::kNone; @@ -236,6 +239,27 @@ void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { } } +void ImportStringArray(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_STRING); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_string_data = + output_array->GetMutableBuffer().data; + output_string_data.resize(input_flat_size); + if (input_flat_size != input_tensor.string_val_size()) { + LOG(FATAL) << "Input_content string_val doesn't have the right " + "dimensions for this string tensor."; + } + for (int i = 0; i < input_flat_size; ++i) { + output_string_data[i] = input_tensor.string_val(i); + } +} + // Count the number of inputs of a given node. If // `tf_import_flags.drop_control_dependency` is true, count the number of // non-control-dependency inputs. @@ -261,23 +285,30 @@ void ConvertConstOperator(const NodeDef& node, const auto dtype = GetDataTypeAttr(node, "dtype"); auto& array = model->GetOrCreateArray(node.name()); - array.data_type = dtype == DT_FLOAT - ? ArrayDataType::kFloat - : dtype == DT_INT32 - ? ArrayDataType::kInt32 - : dtype == DT_INT64 ? ArrayDataType::kInt64 - : ArrayDataType::kNone; - if (dtype == DT_FLOAT) { - ImportFloatArray(tensor, &array); - } else if (dtype == DT_INT32) { - ImportInt32Array(tensor, &array); - } else if (dtype == DT_INT64) { - ImportInt64Array(tensor, &array); - } else { - // do nothing, silently ignore the Const data. For example, there are consts - // of string type. We just make a dummy buffer to indicate that this array - // does not rely on external input. 
- array.GetMutableBuffer(); + switch (dtype) { + case DT_FLOAT: + array.data_type = ArrayDataType::kFloat; + ImportFloatArray(tensor, &array); + break; + case DT_INT32: + array.data_type = ArrayDataType::kInt32; + ImportInt32Array(tensor, &array); + break; + case DT_INT64: + array.data_type = ArrayDataType::kInt64; + ImportInt64Array(tensor, &array); + break; + case DT_STRING: + array.data_type = ArrayDataType::kString; + ImportStringArray(tensor, &array); + break; + default: + array.data_type = ArrayDataType::kNone; + // do nothing, silently ignore the Const data. + // We just make a dummy buffer to indicate that + // this array does not rely on external input. + array.GetMutableBuffer(); + break; } } @@ -1191,7 +1222,7 @@ void ConvertGatherOperator(const NodeDef& node, CHECK_EQ(node.op(), "Gather"); CHECK_EQ(GetInputsCount(node, tf_import_flags), 2); const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); - CHECK(indices_data_type == DT_INT32); + CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64); auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 253b163649f..5e7cc294967 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -153,7 +153,15 @@ enum class AxesOrder { // because we'll be dropping the array anyway (e.g. some exotic array types // may be involved only in debug-only subgraphs that we may not be interested // in actually supporting). -enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 }; +enum class ArrayDataType { + kNone, + kBool, + kFloat, + kUint8, + kInt32, + kInt64, + kString +}; // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type template @@ -182,6 +190,10 @@ template <> struct DataTypeImpl { typedef int64 Type; }; +template <> +struct DataTypeImpl { + typedef string Type; +}; template using DataType = typename DataTypeImpl::Type; diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index 5b4dbfae247..a6fa0237bc0 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -53,6 +53,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { return ::tflite::TensorType_INT32; case ArrayDataType::kUint8: return ::tflite::TensorType_UINT8; + case ArrayDataType::kString: + return ::tflite::TensorType_STRING; default: // FLOAT32 is filled for unknown data types. // TODO(ycling): Implement type inference in TF Lite interpreter. 
@@ -66,6 +68,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { return ArrayDataType::kFloat; case ::tflite::TensorType_INT32: return ArrayDataType::kInt32; + case ::tflite::TensorType_STRING: + return ArrayDataType::kString; case ::tflite::TensorType_UINT8: return ArrayDataType::kUint8; default: @@ -82,6 +86,8 @@ flatbuffers::Offset> DataBuffer::Serialize( return CopyBuffer(array, builder); case ArrayDataType::kInt32: return CopyBuffer(array, builder); + case ArrayDataType::kString: + return CopyBuffer(array, builder); case ArrayDataType::kUint8: return CopyBuffer(array, builder); default: @@ -99,6 +105,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, return CopyBuffer(buffer, array); case ::tflite::TensorType_INT32: return CopyBuffer(buffer, array); + case ::tflite::TensorType_STRING: + return CopyBuffer(buffer, array); case ::tflite::TensorType_UINT8: return CopyBuffer(buffer, array); default: diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 381168d15a5..fd56f64cc9b 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -316,6 +316,9 @@ void LogArray(int log_level, const Model& model, const string& name) { case ArrayDataType::kUint8: VLOG(log_level) << " Data type: kUint8"; break; + case ArrayDataType::kString: + VLOG(log_level) << " Data type: kString"; + break; default: VLOG(log_level) << " Data type: other (numerical value: " << static_cast(array.data_type) << ")"; @@ -334,6 +337,9 @@ void LogArray(int log_level, const Model& model, const string& name) { case ArrayDataType::kUint8: VLOG(log_level) << " Final type: kUint8"; break; + case ArrayDataType::kString: + VLOG(log_level) << " Final type: kString"; + break; default: VLOG(log_level) << " Final type: other (numerical value: " << static_cast(array.data_type) << ")"; @@ -1253,6 +1259,11 @@ int ElementSize(ArrayDataType data_type) { return 1; case ArrayDataType::kInt64: return 8; + // Usually not critical limitation because strings are only input and/or + // output. + case ArrayDataType::kString: + LOG(FATAL) << "Transient arrays with strings are not supported yet"; + return 0; default: LOG(FATAL) << "Should not get here."; return 0; diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 20be819e07d..245fe07f2bc 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -82,22 +82,6 @@ py_test( ], ) -py_test( - name = "utils_test", - size = "small", - srcs = ["python/saved_model/utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":saved_model_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:variables", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:tag_constants", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/core/api_def/base_api/api_def_UniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_UniqueDataset.pbtxt new file mode 100644 index 00000000000..00925691690 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UniqueDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "UniqueDataset" + summary: "Creates a dataset that contains the unique elements of `input_dataset`." 
+} diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index e4c1da52ec0..8eeaf31b26c 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -35,25 +35,41 @@ bool IsAdd(const NodeDef& node) { bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; } +bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; } + bool IsAnyDiv(const NodeDef& node) { return node.op() == "RealDiv" || node.op() == "Div" || node.op() == "FloorDiv" || node.op() == "TruncateDiv"; } +bool IsApproximateEqual(const NodeDef& node) { + return node.op() == "ApproximateEqual"; +} + bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; } bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; } +bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; } + bool IsBiasAdd(const NodeDef& node) { return node.op() == "BiasAdd" || node.op() == "BiasAddV1"; } bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; } +bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; } + +bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; } + +bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; } + bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; } bool IsConstant(const NodeDef& node) { return node.op() == "Const"; } +bool IsConj(const NodeDef& node) { return node.op() == "Conj"; } + bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } bool IsConv2DBackpropFilter(const NodeDef& node) { @@ -92,39 +108,77 @@ bool IsEnter(const NodeDef& node) { return op == "Enter" || op == "RefEnter"; } +bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; } + bool IsExit(const NodeDef& node) { const auto& op = node.op(); return op == "Exit" || op == "RefExit"; } +bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; } + bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; } -bool IsFusedBatchNormGradV1(const NodeDef& node) { - return node.op() == "FusedBatchNormGrad"; +bool IsFusedBatchNormGrad(const NodeDef& node) { + const auto& op = node.op(); + return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2"; } +bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; } + +bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; } + bool IsIdentity(const NodeDef& node) { const auto& op = node.op(); return op == "Identity" || op == "RefIdentity"; } +bool IsIdentityN(const NodeDef& node) { + const auto& op = node.op(); + return op == "IdentityN"; +} + +bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; } + +bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; } + +bool IsImag(const NodeDef& node) { return node.op() == "Imag"; } + bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; } +bool IsLess(const NodeDef& node) { return node.op() == "Less"; } + +bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; } + +bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; } + +bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; } + +bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; } + bool IsMatMul(const NodeDef& node) { const auto& op = node.op(); return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" || op == "SparseMatMul"; } +bool IsMaximum(const NodeDef& node) { return node.op() == 
"Maximum"; } + bool IsMerge(const NodeDef& node) { const auto& op = node.op(); return op == "Merge" || op == "RefMerge"; } +bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; } + +bool IsMod(const NodeDef& node) { return node.op() == "Mod"; } + bool IsMul(const NodeDef& node) { return node.op() == "Mul"; } bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; } +bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; } + bool IsNextIteration(const NodeDef& node) { const auto& op = node.op(); return op == "NextIteration" || op == "RefNextIteration"; @@ -138,6 +192,12 @@ bool IsPlaceholder(const NodeDef& node) { op == "PlaceholderWithDefault"; } +bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; } + +bool IsPow(const NodeDef& node) { return node.op() == "Pow"; } + +bool IsReal(const NodeDef& node) { return node.op() == "Real"; } + bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; } bool IsReciprocalGrad(const NodeDef& node) { @@ -209,12 +269,18 @@ bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; } bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; } +bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; } + +bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; } + bool IsVariable(const NodeDef& node) { const auto& op = node.op(); return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || op == "VarHandleOp" || op == "ReadVariableOp"; } +bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; } + namespace { bool GetBoolAttr(const NodeDef& node, const string& name) { return node.attr().count(name) > 0 && node.attr().at(name).b(); @@ -284,5 +350,10 @@ bool IsValuePreserving(const NodeDef& node) { return value_preserving_ops.count(node.op()) > 0; } +bool HasOpDef(const NodeDef& node) { + const OpDef* op_def = nullptr; + return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok(); +} + } // namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 0e246a661f0..7d5d1149f7d 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -24,11 +24,18 @@ namespace grappler { bool IsAdd(const NodeDef& node); bool IsAddN(const NodeDef& node); +bool IsAngle(const NodeDef& node); bool IsAnyDiv(const NodeDef& node); +bool IsApproximateEqual(const NodeDef& node); bool IsAvgPoolGrad(const NodeDef& node); bool IsAssert(const NodeDef& node); +bool IsAtan2(const NodeDef& node); bool IsBiasAdd(const NodeDef& node); bool IsBiasAddGrad(const NodeDef& node); +bool IsBitcast(const NodeDef& node); +bool IsComplex(const NodeDef& node); +bool IsComplexAbs(const NodeDef& node); +bool IsConj(const NodeDef& node); bool IsConcatOffset(const NodeDef& node); bool IsConstant(const NodeDef& node); bool IsConv2D(const NodeDef& node); @@ -41,18 +48,38 @@ bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); +bool IsEqual(const NodeDef& node); bool IsExit(const NodeDef& node); +bool IsFloorDiv(const NodeDef& node); bool IsFloorMod(const NodeDef& node); -bool IsFusedBatchNormGradV1(const NodeDef& node); +bool IsFusedBatchNormGrad(const NodeDef& node); +bool IsGreater(const NodeDef& node); +bool IsGreaterEqual(const NodeDef& node); bool IsIdentity(const NodeDef& node); +bool IsIdentityN(const NodeDef& node); +bool 
IsIgamma(const NodeDef& node); +bool IsIgammac(const NodeDef& node); +bool IsImag(const NodeDef& node); bool IsInvGrad(const NodeDef& node); +bool IsLess(const NodeDef& node); +bool IsLessEqual(const NodeDef& node); +bool IsLogicalAnd(const NodeDef& node); +bool IsLogicalNot(const NodeDef& node); +bool IsLogicalOr(const NodeDef& node); +bool IsMaximum(const NodeDef& node); bool IsMerge(const NodeDef& node); +bool IsMinimum(const NodeDef& node); +bool IsMod(const NodeDef& node); bool IsMul(const NodeDef& node); bool IsMatMul(const NodeDef& node); bool IsNextIteration(const NodeDef& node); bool IsPad(const NodeDef& node); bool IsNoOp(const NodeDef& node); +bool IsNotEqual(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); +bool IsPolygamma(const NodeDef& node); +bool IsPow(const NodeDef& node); +bool IsReal(const NodeDef& node); bool IsRealDiv(const NodeDef& node); bool IsRelu6Grad(const NodeDef& node); bool IsReluGrad(const NodeDef& node); @@ -80,7 +107,10 @@ bool IsSum(const NodeDef& node); bool IsSwitch(const NodeDef& node); bool IsTanhGrad(const NodeDef& node); bool IsTranspose(const NodeDef& node); +bool IsTruncateDiv(const NodeDef& node); +bool IsTruncateMod(const NodeDef& node); bool IsVariable(const NodeDef& node); +bool IsZeta(const NodeDef& node); // Return true if the op is an aggregation (e.g. Add, AddN). // Returns false if it could not be determined to be so. @@ -102,6 +132,9 @@ bool IsInvolution(const NodeDef& node); // function returns true if the op commutes with all element-wise operations. bool IsValuePreserving(const NodeDef& node); +// Returns true if we can find an opdef corresponding to the op of the node. +bool HasOpDef(const NodeDef& node); + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 6cc50845b31..f41d954c4cb 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" +#include #include #include "tensorflow/core/framework/node_def.pb.h" @@ -350,15 +351,16 @@ Status DependencyOptimizer::TransitiveReduction() { num_nodes); for (int node_idx = 0; node_idx < num_nodes; ++node_idx) { const NodeDef& node = optimized_graph_->node(node_idx); - if (ModifiesFrameInfo(node)) { - // Ignore nodes that modify frame info. + if (ModifiesFrameInfo(node) || !HasOpDef(node)) { + // Ignore function nodes and nodes that modify frame info. continue; } for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) { const string& input = node.input(input_slot); const NodeDef* input_node = node_map_->GetNode(input); - if (ModifiesFrameInfo(*input_node)) { - // Ignore edges from nodes that modify frame info. + if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) { + // Ignore edges from nodes that modify frame info and from Merge nodes, + // because we cannot know which of it's input paths executes. continue; } const int input_node_idx = node_to_idx_[input_node]; @@ -375,6 +377,14 @@ Status DependencyOptimizer::TransitiveReduction() { // of length > 1, we can drop that control dependency. int num_controls_removed = 0; std::vector longest_distance(num_nodes); + // Map from target_index -> set of (input_slot, source_index), representing + // the control edges to remove. 
We sort them in reverse order by input slot, + so that when we swap them out we don't clobber the + node(target).input() repeated field. + typedef std::pair InputSlotAndSource; + std::unordered_map< + int, std::set>> + control_edges_to_remove; for (int source = 0; source < num_nodes; ++source) { int highest_control_target = -1; for (const auto& control_output : control_outputs[source]) { @@ -382,7 +392,7 @@ highest_control_target = control_output.first; } } - if (highest_control_target < source) { + if (highest_control_target <= source) { continue; } std::fill(longest_distance.begin() + source, @@ -391,7 +401,10 @@ for (int input : inputs[target]) { // If the input node is before source in the topo order, no path // source -> input -> target can exist and we can skip it. - if (input >= source) { + // Also only extend a path from the source itself or from nodes that + // have a path from source, indicated by longest_distance[input] > 0. + if (input == source || + (input > source && longest_distance[input] > 0)) { // If source -> input -> target is longer than the longest // path so far from source -> target, update the longest_distance. int candidate_longest_distance = longest_distance[input] + 1; @@ -402,25 +415,36 @@ } } - // If the longest path from the source to the target of a control dependency - // is longer than 1, there exists an alternate path, and we can eliminate - // the control dependency since it is redundant. + // If the longest path from source to target of a control dependency is + // longer than 1, there exists an alternate path, and we can eliminate the + // redundant direct control dependency. for (const auto& control_output : control_outputs[source]) { const int target = control_output.first; if (longest_distance[target] > 1) { const int input_slot = control_output.second; - // We modify the node inplace here. This is safe because there can - // only be one control edge from a given source to a given target.
- const NodeDef& source_node = optimized_graph_->node(source); - NodeDef* target_node = optimized_graph_->mutable_node(target); - target_node->mutable_input()->SwapElements( - input_slot, target_node->input_size() - 1); - node_map_->RemoveOutput(source_node.name(), target_node->name()); - target_node->mutable_input()->RemoveLast(); - ++num_controls_removed; + control_edges_to_remove[target].emplace(input_slot, source); + VLOG(1) << "Removing edge from:\n" + << optimized_graph_->node(source).DebugString() << "\n\nto:\n\n" + << optimized_graph_->node(target).DebugString(); } } } + + for (const auto& it : control_edges_to_remove) { + const int target = it.first; + NodeDef* target_node = optimized_graph_->mutable_node(target); + for (const InputSlotAndSource& slot_and_source : it.second) { + const int input_slot = slot_and_source.first; + const int source = slot_and_source.second; + const NodeDef& source_node = optimized_graph_->node(source); + CHECK_LT(input_slot, target_node->input_size()); + target_node->mutable_input()->SwapElements(input_slot, + target_node->input_size() - 1); + node_map_->RemoveOutput(source_node.name(), target_node->name()); + target_node->mutable_input()->RemoveLast(); + ++num_controls_removed; + } + } VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls << " control dependencies"; return Status::OK(); @@ -442,36 +466,27 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, nodes_to_preserve_ = item.NodesToPreserve(); fetch_nodes_known_ = !item.fetch.empty(); - VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString(); CleanControlInputs(); - const int num_iterations = opt_level_ == RewriterConfig::AGGRESSIVE ? 2 : 1; + const int num_iterations = 2; for (int iteration = 0; iteration < num_iterations; ++iteration) { Status topo_sort_status; - if (opt_level_ == RewriterConfig::AGGRESSIVE) { - // Prepare the graph for transitive reduction if enabled. - topo_sort_status = TopologicalSort(optimized_graph_); - } + // Perform topological sort to prepare the graph for transitive reduction. + topo_sort_status = TopologicalSort(optimized_graph_); + // Set up index-based graph datastructures to speed up analysis steps below. node_map_.reset(new NodeMap(optimized_graph_)); BuildNodeToIdx(); - // Remove redundant control dependencies, iteration 1. - if (opt_level_ == RewriterConfig::AGGRESSIVE) { - if (topo_sort_status.ok()) { - TF_RETURN_IF_ERROR(TransitiveReduction()); - } else { - LOG(ERROR) << topo_sort_status.error_message(); - } - VLOG(1) << "Graph after transitive reduction:\n" - << optimized_graph_->DebugString(); + if (topo_sort_status.ok()) { + // Remove redundant control dependencies. + TF_RETURN_IF_ERROR(TransitiveReduction()); + } else { + LOG(ERROR) << topo_sort_status.error_message(); } - // Turn nodes without non-control outputs into NoOps, prune NoOps. + // Turn nodes with only control outputs into NoOps, prune NoOps. 
TF_RETURN_IF_ERROR(OptimizeDependencies()); - VLOG(1) << "Graph after NoOp conversion & pruning:\n" - << optimized_graph_->DebugString(); } - VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString(); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index 837fbba2fc1..c6d81ee60d7 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -157,7 +157,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE); + DependencyOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -228,6 +228,7 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) { // The optimization should be disabled to prevent increasing the number of // nodes crossing device boundaries. + TF_CHECK_OK(TopologicalSort(&item.graph)); VerifyGraphsEqual(item.graph, output, __FUNCTION__); } @@ -282,7 +283,7 @@ TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch.push_back("id2"); - DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE); + DependencyOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index bcf785f2722..cf15a44b39c 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -60,7 +60,9 @@ std::set GetOpsFormatSupported() { "DepthwiseConv2dNativeBackpropInput", "DepthwiseConv2dNativeBackpropFilter", "FusedBatchNorm", + "FusedBatchNormV2", "FusedBatchNormGrad", + "FusedBatchNormGradV2", "FusedConv2DBiasActivation", "MaxPool", "MaxPoolGrad", @@ -75,52 +77,77 @@ std::set GetOpsFormatAgnostic() { std::set ops_format_agnostic = {"Abs", "Add", "AddN", + "AddV2", "Acos", "Acosh", "Angle", + "ApproximateEqual", "Asin", "Asinh", "Atan", + "Atan2", "Atanh", "Bitcast", "Cast", "Ceil", "CheckNumerics", - "Cos", - "Cosh", + "Complex", "ComplexAbs", "Concat", "ConcatV2", "Conj", + "Cos", + "Cosh", "Digamma", + "Div", "Elu", "EluGrad", + "Equal", "Erf", "Erfc", "Exp", "Expm1", "Floor", + "FloorDiv", + "FloorMod", + "Greater", + "GreaterEqual", "GuaranteeConst", "Identity", + "IdentityN", + "Igamma", + "Igammac", "Imag", "Inv", "InvGrad", "IsFinite", "IsInf", "IsNan", + "Less", + "LessEqual", "Lgamma", "Log", + "LogicalAnd", + "LogicalNot", + "LogicalOr", "Log1p", + "Maximum", "Merge", + "Minimum", + "Mod", "Mul", "Neg", + "NotEqual", "OnesLike", "Pad", "PreventGradient", + "Polygamma", + "Pow", "Real", "RealDiv", "Reciprocal", "ReciprocalGrad", + "RefIdentity", "Relu", "Relu6", "Relu6Grad", @@ -141,7 +168,8 @@ std::set GetOpsFormatAgnostic() { "SoftplusGrad", "Split", "Switch", - "RefIdentity", + "TruncateDiv", + "TruncateMod", "RefMerge", "RefSwitch", "Round", @@ -157,7 +185,8 @@ std::set GetOpsFormatAgnostic() { "Tan", "Tanh", "TanhGrad", - "ZerosLike"}; + "ZerosLike", + "Zeta"}; return ops_format_agnostic; } @@ -212,6 +241,28 @@ bool IsUnaryGrad(const NodeDef& node) { return is_unary_grad; } +bool IsComparisonOp(const NodeDef& node) { + bool is_compare = 
IsApproximateEqual(node) || IsEqual(node) || + IsGreater(node) || IsGreaterEqual(node) || IsLess(node) || + IsLessEqual(node) || IsNotEqual(node); + return is_compare; +} + +bool IsLogicalOp(const NodeDef& node) { + return IsLogicalAnd(node) || IsLogicalNot(node) || IsLogicalOr(node); +} + +bool IsBinaryOp(const NodeDef& node) { + bool is_binary = + IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) || + IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) || + IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) || + IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) || + IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) || + IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node); + return is_binary; +} + class GraphProcessor { public: GraphProcessor(const VirtualPlacer& virtual_placer, @@ -409,17 +460,19 @@ class NodeProcessor : public GraphProcessor { virtual void UpdateAttrShape() { if (node_->attr().find("_output_shapes") != node_->attr().end()) { - auto shape = node_->mutable_attr() - ->at("_output_shapes") - .mutable_list() - ->mutable_shape(0); - if (shape->dim_size() == 4) { - int64 h = shape->dim(1).size(); - int64 w = shape->dim(2).size(); - int64 c = shape->dim(3).size(); - shape->mutable_dim(1)->set_size(c); - shape->mutable_dim(2)->set_size(h); - shape->mutable_dim(3)->set_size(w); + for (const auto& pos : GetOutputPos()) { + auto shape = node_->mutable_attr() + ->at("_output_shapes") + .mutable_list() + ->mutable_shape(pos); + if (shape->dim_size() == 4) { + int64 h = shape->dim(1).size(); + int64 w = shape->dim(2).size(); + int64 c = shape->dim(3).size(); + shape->mutable_dim(1)->set_size(c); + shape->mutable_dim(2)->set_size(h); + shape->mutable_dim(3)->set_size(w); + } } } } @@ -603,13 +656,15 @@ class NodeProcessor : public GraphProcessor { added_node_name = AddPrefixToNodeName(added_node_base_name, kTransposeNCHWToNHWC, "-"); DataType dtype; - if (op == "Imag" || op == "Real" || op == "Angle" || - op == "Conj" || op == "ComplexAbs") { + if (IsAngle(*node_) || IsComplex(*node_) || + IsComplexAbs(*node_) || IsImag(*node_) || IsReal(*node_)) { TF_RETURN_IF_ERROR(HasAttribute(*node_, "Tout")); dtype = node_->attr().at("Tout").type(); - } else if (op == "Bitcast") { + } else if (IsBitcast(*node_)) { TF_RETURN_IF_ERROR(HasAttribute(*node_, "type")); dtype = node_->attr().at("type").type(); + } else if (IsLogicalOp(*node_) || IsComparisonOp(*node_)) { + dtype = DT_BOOL; } else { TF_RETURN_IF_ERROR(HasAttribute(*node_, "T")); dtype = node_->attr().at("T").type(); @@ -617,7 +672,8 @@ class NodeProcessor : public GraphProcessor { TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes")); AddNodeTranspose( added_node_name, input, const_name, dtype, - node_->attr().at("_output_shapes").list().shape(0), false); + node_->attr().at("_output_shapes").list().shape(input_port), + false); } else if (op == "DataFormatVecPermute") { added_node_name = AddPrefixToNodeName(added_node_base_name, kVecPermuteNCHWToNHWC, "-"); @@ -1002,11 +1058,10 @@ class AgnosticNodeProcessor : public NodeProcessor { if (IsConcatV1(node)) { return {1}; } - if (IsAdd(node) || IsMul(node) || IsRealDiv(node) || - IsSquaredDifference(node) || IsSub(node)) { + if (IsBinaryOp(node) || IsUnaryGrad(node)) { return {0, 1}; } - if (IsShapeN(node)) { + if (IsShapeN(node) || IsIdentityN(node)) { std::vector pos; for (int i = 0; i < node.input_size(); i++) { pos.push_back(i); @@ -1207,6 +1262,40 @@ class ConcatProcessor : public 
AgnosticNodeProcessor { int axis_node_pos_; }; +class IdentityNProcessor : public AgnosticNodeProcessor { + public: + explicit IdentityNProcessor(const OptimizeContext& opt_cxt) + : AgnosticNodeProcessor(opt_cxt) {} + + protected: + bool ShouldProcess() const override { + return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && + IsOnGPU(); + } + + std::vector GetInputPos() const override { + std::vector input_pos; + for (int i = 0; i < node_->input_size(); i++) { + auto input = node_map_->GetNode(node_->input(i)); + int port; + ParseNodeName(node_->input(i), &port); + if (IsPortDimsFour(*input, port) && + (IsNodeAfterNCHWToNHWC(*input) || IsNodeNCHWToNHWC(input->name()))) { + input_pos.push_back(i); + } + } + return input_pos; + } + + std::set GetOutputPos() const override { + std::set output_pos{}; + for (const auto& input_pos : GetInputPos()) { + output_pos.insert(input_pos); + } + return output_pos; + } +}; + class MergeProcessor : public AgnosticNodeProcessor { public: explicit MergeProcessor(const OptimizeContext& opt_cxt) @@ -1371,12 +1460,15 @@ class SqueezeProcessor : public AgnosticNodeProcessor { Status AddLayoutTransposeToOutputs() override { return Status::OK(); } bool IsInputConvertible() const { + int input_port; auto input = node_map_->GetNode(node_->input(0)); + ParseNodeName(node_->input(0), &input_port); if (IsNodeNCHWToNHWC(input->name())) { input = node_map_->GetNode(input->input(0)); + ParseNodeName(input->input(0), &input_port); } if (input->attr().find("_output_shapes") != input->attr().end()) { - auto shape = input->attr().at("_output_shapes").list().shape(0); + auto shape = input->attr().at("_output_shapes").list().shape(input_port); if (shape.dim_size() != 4) { return false; } @@ -1529,7 +1621,7 @@ class DataLayoutOptimizer : GraphProcessor { new Conv2DBackpropFilterProcessor(opt_cxt, true)); } else if (IsDepthwiseConv2dNativeBackpropInput(*node)) { node_processor.reset(new Conv2DBackpropInputProcessor(opt_cxt, true)); - } else if (IsFusedBatchNormGradV1(*node)) { + } else if (IsFusedBatchNormGrad(*node)) { node_processor.reset(new FusedBatchNormGradProcessor(opt_cxt)); } else if (IsMaxPoolGradV1(*node)) { node_processor.reset(new MaxPoolGradProcessor(opt_cxt)); @@ -1557,11 +1649,12 @@ class DataLayoutOptimizer : GraphProcessor { std::unique_ptr node_processor; if (IsAddN(*node)) { node_processor.reset(new AddNProcessor(opt_cxt)); - } else if (IsAdd(*node) || IsMul(*node) || IsRealDiv(*node) || - IsSquaredDifference(*node) || IsSub(*node)) { + } else if (IsBinaryOp(*node)) { node_processor.reset(new BinaryOpProcessor(opt_cxt)); } else if (IsConcat(*node)) { node_processor.reset(new ConcatProcessor(opt_cxt)); + } else if (IsIdentityN(*node)) { + node_processor.reset(new IdentityNProcessor(opt_cxt)); } else if (IsMerge(*node)) { node_processor.reset(new MergeProcessor(opt_cxt)); } else if (IsPad(*node)) { diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index 98109f724ee..781f942532d 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -1066,6 +1066,48 @@ TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) { "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-0-1"); } +TEST_F(LayoutOptimizerTest, Complex) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 4, 2, "VALID"); + auto comp = ops::Complex(s.WithOpName("complex"), conv, conv); + auto 
i = ops::Identity(s.WithOpName("i"), comp); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto merge_node = node_map.GetNode("complex"); + EXPECT_EQ(merge_node->input(0), "Conv2D"); + EXPECT_EQ(merge_node->input(1), "Conv2D"); + auto trans = + node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-complex-0-0"); + EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64); +} + +TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 4, 2, "VALID"); + auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2}); + auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv}); + auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto i = node_map.GetNode("identity_n"); + EXPECT_EQ(i->input(0), "vector"); + EXPECT_EQ(i->input(1), "Conv2D"); + auto trans = + node_map.GetNode("LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1"); + EXPECT_EQ(trans->input(0), "identity_n:1"); + auto add_node = node_map.GetNode("add"); + EXPECT_EQ(add_node->input(0), "identity_n"); + EXPECT_EQ(add_node->input(1), + "LayoutOptimizerTransposeNCHWToNHWC-identity_n-0-1"); +} } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc index a83671a471c..6cec032f949 100644 --- a/tensorflow/core/kernels/cuda_solvers.cc +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -314,6 +314,11 @@ Status CudaSolver::forward_input_or_allocate_scoped_tensor( // are sometimes inaccurate, e.g., are missing 'const' on pointers // to immutable arguments, while the actual headers have them as expected. // Check the actual declarations in the cusolver_api.h header file. +// +// NOTE: The cuSolver functions called below appear not to be threadsafe. +// so we put a global lock around the calls. Since these functions only put a +// kernel on the shared stream, it is not a big performance hit. +// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9. //============================================================================= template @@ -324,6 +329,7 @@ static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle, const Scalar* A, int lda, const Scalar* beta, /* host or device pointer */ const Scalar* B, int ldb, Scalar* C, int ldc) { + mutex_lock lock(handle_map_mutex); using CudaScalar = typename CUDAComplexT::type; TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n, reinterpret_cast(alpha), @@ -355,6 +361,7 @@ static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver, cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n, Scalar* A, int lda, int* dev_lapack_info) { + mutex_lock lock(handle_map_mutex); /* Get amount of workspace memory required. 
*/ int lwork; TF_RETURN_IF_CUSOLVER_ERROR( @@ -387,6 +394,7 @@ static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver, cusolverDnHandle_t cusolver_dn_handle, int m, int n, Scalar* A, int lda, int* dev_pivots, int* dev_lapack_info) { + mutex_lock lock(handle_map_mutex); /* Get amount of workspace memory required. */ int lwork; TF_RETURN_IF_CUSOLVER_ERROR( @@ -419,9 +427,6 @@ static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context, cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) { - // Note: The cuSolver functions called here appear not to be threadsafe. - // so we put a global lock around it. Since this function only puts a - // kernel on the stream, it is not a big performance hit. mutex_lock lock(handle_map_mutex); /* Launch the solver kernel. */ TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs, @@ -449,6 +454,7 @@ static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver, cusolverDnHandle_t cusolver_dn_handle, int m, int n, Scalar* A, int lda, Scalar* tau, int* dev_lapack_info) { + mutex_lock lock(handle_map_mutex); /* Get amount of workspace memory required. */ int lwork; TF_RETURN_IF_CUSOLVER_ERROR( @@ -483,6 +489,7 @@ static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver, int m, int n, int k, const Scalar* dev_a, int lda, const Scalar* dev_tau, Scalar* dev_c, int ldc, int* dev_lapack_info) { + mutex_lock lock(handle_map_mutex); /* Get amount of workspace memory required. */ int lwork; TF_RETURN_IF_CUSOLVER_ERROR( @@ -526,6 +533,7 @@ static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver, cusolverDnHandle_t cusolver_dn_handle, int m, int n, int k, Scalar* dev_a, int lda, const Scalar* dev_tau, int* dev_lapack_info) { + mutex_lock lock(handle_map_mutex); /* Get amount of workspace memory required. */ int lwork; TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k, @@ -606,17 +614,13 @@ static inline Status GesvdImpl( OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle, signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda, Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) { + mutex_lock lock(handle_map_mutex); /* Get amount of workspace memory required. */ int lwork; TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork)); /* Allocate device memory for workspace. */ auto dev_workspace = cuda_solver->GetScratchSpace(lwork, "", /* on_host */ false); - // Note: The cuSolver functions called here appear not to be threadsafe. - // so we put a global lock around it. Since this function only puts a - // kernel on the stream, it is not a big performance hit. - mutex_lock lock(handle_map_mutex); - /* Launch the solver kernel. 
*/ TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n, CUDAComplex(A), lda, S, CUDAComplex(U), ldu, CUDAComplex(VT), ldvt, @@ -655,6 +659,7 @@ static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver, int lda, int* dev_pivots, DeviceLapackInfo* dev_lapack_info, int batch_size) { + mutex_lock lock(handle_map_mutex); using CudaScalar = typename CUDAComplexT::type; ScratchSpace dev_a_dev_ptrs = cuda_solver->GetScratchSpace(sizeof(CudaScalar*) * batch_size, "", @@ -689,6 +694,7 @@ static inline Status GetrsBatchedImpl( const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, const Scalar* const host_b_dev_ptrs[], int ldb, DeviceLapackInfo* dev_lapack_info, int batch_size) { + mutex_lock lock(handle_map_mutex); using CudaScalar = typename CUDAComplexT::type; ScratchSpace dev_a_dev_ptrs = cuda_solver->GetScratchSpace(sizeof(CudaScalar*) * batch_size, "", @@ -734,6 +740,7 @@ static inline Status GetriBatchedImpl( cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[], int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) { + mutex_lock lock(handle_map_mutex); using CudaScalar = typename CUDAComplexT::type; ScratchSpace dev_a_dev_ptrs = cuda_solver->GetScratchSpace(sizeof(CudaScalar*) * batch_size, "", @@ -776,6 +783,7 @@ static inline Status MatInvBatchedImpl( cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[], int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) { + mutex_lock lock(handle_map_mutex); using CudaScalar = typename CUDAComplexT::type; ScratchSpace dev_a_dev_ptrs = cuda_solver->GetScratchSpace(sizeof(CudaScalar*) * batch_size, "", diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 58cf36f454c..0af2e523626 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -493,6 +493,18 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "unique_dataset_op", + srcs = ["unique_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + tf_kernel_library( name = "dataset_ops", deps = [ @@ -526,6 +538,7 @@ tf_kernel_library( ":take_dataset_op", ":tensor_dataset_op", ":tensor_slice_dataset_op", + ":unique_dataset_op", ":zip_dataset_op", ], ) diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index fe0e498a3b7..6735373aaca 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -55,10 +55,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { *output = nullptr; -#define HANDLE_TYPE(T) \ - case DataTypeToEnum::value: { \ - *output = new Dataset(batch_size, row_shape, input); \ - break; \ +#define HANDLE_TYPE(T) \ + case DataTypeToEnum::value: { \ + *output = new Dataset(ctx, batch_size, row_shape, input); \ + break; \ } switch (input->output_dtypes()[0]) { @@ -75,11 +75,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { private: // TODO(mrry): Push the templated code down to the raw copying routine. 
template - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(int64 batch_size, const PartialTensorShape& row_shape, - const DatasetBase* input) - : batch_size_(batch_size), row_shape_(row_shape), input_(input) { + Dataset(OpKernelContext* ctx, int64 batch_size, + const PartialTensorShape& row_shape, const DatasetBase* input) + : GraphDatasetBase(ctx), + batch_size_(batch_size), + row_shape_(row_shape), + input_(input) { input_->Ref(); output_shapes_.reserve(3); @@ -112,6 +115,25 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { ")::Dataset"); } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_node; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); + Node* batch_size_node; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node)); + Node* row_shape_node; + std::vector row_shape; + row_shape.reserve( + row_shape_.dims()); // not an unknown rank PartialTensorShape + for (int i = 0; i < row_shape_.dims(); i++) + row_shape.emplace_back(row_shape_.dim_size(i)); + TF_RETURN_IF_ERROR(b->AddVector(row_shape, &row_shape_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_node, batch_size_node, row_shape_node}, output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator> { public: @@ -242,6 +264,20 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(Iterator::SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(Iterator::RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + private: mutex mu_; std::unique_ptr input_impl_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index 84ba0514687..413d0c0f578 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -64,20 +64,23 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { std::move(other_arguments), &captured_func)); - *output = - new Dataset(input, std::move(initial_state), std::move(captured_func), - state_types_, output_types_, output_shapes_); + *output = new Dataset(ctx, input, func_, std::move(initial_state), + std::move(captured_func), state_types_, output_types_, + output_shapes_); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(const DatasetBase* input, std::vector initial_state, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, std::vector initial_state, std::unique_ptr captured_func, const DataTypeVector& state_types, const DataTypeVector& output_types, const std::vector& output_shapes) - : input_(input), + : GraphDatasetBase(ctx), + input_(input), + func_(func), initial_state_(std::move(initial_state)), captured_func_(std::move(captured_func)), state_types_(state_types), @@ -103,6 +106,45 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "ScanDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, 
func_.name())); + Node* input_node; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); + std::vector initial_state_nodes; + initial_state_nodes.reserve(initial_state_.size()); + for (const Tensor& t : initial_state_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + initial_state_nodes.emplace_back(node); + } + std::vector other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(func_, &f); + AttrValue state_types; + b->BuildAttrValue(state_types_, &state_types); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {{0, input_node}}, + {{1, initial_state_nodes}, {2, other_arguments}}, + {{"f", f}, + {"Tstate", state_types}, + {"Targuments", other_arguments_types_attr}}, + output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -185,6 +227,38 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { return s; } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + if (!state_.empty()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("state_size"), state_.size())); + for (int idx = 0; idx < state_.size(); idx++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("state[", idx, "]")), state_[idx])); + } + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + if (reader->Contains(full_name("state_size"))) { + int64 size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("state_size"), &size)); + state_.resize(size); + for (int idx = 0; idx < size; idx++) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("state[", idx, "]")), &state_[idx])); + } + } + return Status::OK(); + } + private: mutex mu_; const std::unique_ptr input_impl_ GUARDED_BY(mu_); @@ -192,6 +266,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; + const NameAttrList func_; const std::vector initial_state_; const std::unique_ptr captured_func_; const DataTypeVector state_types_; diff --git a/tensorflow/core/kernels/data/unique_dataset_op.cc b/tensorflow/core/kernels/data/unique_dataset_op.cc new file mode 100644 index 00000000000..8401b731175 --- /dev/null +++ b/tensorflow/core/kernels/data/unique_dataset_op.cc @@ -0,0 +1,219 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class UniqueDatasetOp : public UnaryDatasetOpKernel { + public: + explicit UniqueDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OP_REQUIRES(ctx, input->output_dtypes().size() == 1, + errors::InvalidArgument("UniqueDataset only supports " + "inputs with a single component.")); + + DataType input_dtype = input->output_dtypes()[0]; + OP_REQUIRES(ctx, + input_dtype == DT_INT32 || input_dtype == DT_INT64 || + input_dtype == DT_STRING, + errors::InvalidArgument( + "UniqueDataset only supports inputs with a single " + "`tf.int32`, `tf.int64`, or `tf.string` component.")); + + *output = new Dataset(ctx, input); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input) + : GraphDatasetBase(ctx), input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Unique")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() override { + return strings::StrCat("UniqueDatasetOp::Dataset"); + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const typename Iterator::Params& params) + : DatasetIterator(params), + input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + bool saw_new_value; + do { + saw_new_value = false; + out_tensors->clear(); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence) { + break; + } + DCHECK_EQ(1, out_tensors->size()); + saw_new_value = unique_elements_.insert((*out_tensors)[0]).second; + } while (!saw_new_value); + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("unique_elements_size"), unique_elements_.size())); + size_t i = 0; + for (const Tensor& t : unique_elements_) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("unique_elements[", i++, "]")), t)); + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + 
IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } + int64 num_unique_elements; + unique_elements_.clear(); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"), + &num_unique_elements)); + for (int64 i = 0; i < num_unique_elements; ++i) { + Tensor unique_element; + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("unique_elements[", i, "]")), + &unique_element)); + auto insert_result = unique_elements_.insert(unique_element); + if (!insert_result.second) { + return errors::InvalidArgument( + "Checkpoint contained two unique elements with the same " + "value."); + } + } + return Status::OK(); + } + + private: + struct TensorHash { + size_t operator()(const Tensor& t) const { + if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) { + return Hash64(t.tensor_data().data(), t.tensor_data().size()); + } else { + DCHECK_EQ(DT_STRING, t.dtype()); + auto flat_t = t.flat(); + uint64 hash = 0; + for (int64 i = 0; i < t.NumElements(); ++i) { + hash = Hash64Combine(hash, Hash64(flat_t(i))); + } + return static_cast(hash); + } + } + }; + + struct TensorKeyEqual { + bool operator()(const Tensor& lhs, const Tensor& rhs) const { + if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) { + return false; + } + switch (lhs.dtype()) { +#define HANDLE_TYPE(T) \ + case T: \ + do { \ + auto lhs_flat = lhs.flat::Type>(); \ + auto rhs_flat = rhs.flat::Type>(); \ + for (int64 i = 0; i < lhs.NumElements(); ++i) { \ + if (lhs_flat(i) != rhs_flat(i)) { \ + return false; \ + } \ + } \ + return true; \ + } while (0) + + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_INT64); + HANDLE_TYPE(DT_STRING); + default: + LOG(FATAL) << "UniqueDataset unhandled data type: " + << DataTypeString(lhs.dtype()); + } + } + }; + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::unordered_set unique_elements_ + GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU), + UniqueDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 2072e0df57f..66e5c163288 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -558,6 +558,16 @@ filename: A path on the filesystem where we should cache the dataset. Note: this will be a directory. )doc"); +REGISTER_OP("UniqueDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the unique elements of `input_dataset`. +)doc"); + REGISTER_OP("TextLineDataset") .Input("filenames: string") .Input("compression_type: string") diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index c2533b4314e..86943d34a69 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/lib/core/errors.h" @@ -327,6 +329,57 @@ Status CurlHttpRequest::SetResultBuffer(std::vector* out_buffer) { return Status::OK(); } +Status CurlHttpRequest::SetResultBufferDirect(char* buffer, size_t size) { + CHECK(buffer != nullptr); + TF_RETURN_IF_ERROR(CheckInitialized()); + TF_RETURN_IF_ERROR(CheckNotSent()); + + direct_response_ = DirectResponseState{buffer, size, 0}; + + libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA, + reinterpret_cast(this)); + libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION, + &CurlHttpRequest::WriteCallbackDirect); + return Status::OK(); +} + +size_t CurlHttpRequest::WriteCallbackDirect(const void* ptr, size_t size, + size_t nmemb, void* userdata) { + CHECK(ptr != nullptr); + auto that = reinterpret_cast(userdata); + DirectResponseState* state = &that->direct_response_; + CHECK(state->buffer_ != nullptr); + CHECK(state->bytes_transferred_ <= state->buffer_size_); + + size_t curl_bytes_received = size * nmemb; + size_t user_buffer_bytes_available = + state->buffer_size_ - state->bytes_transferred_; + + // The HTTP server may send a response body that is longer than what we + // expected. We must not use CHECK() for this situation, because that would + // imply a code bug (in this client code) where none exists; the violation of + // expectations would have been caused by the server, not the client. So we + // report a log warning, if an HTTP server is misbehaving. + if (curl_bytes_received > user_buffer_bytes_available) { + LOG(WARNING) << "The HTTP response body that we received is longer than we " + "requested or expected. " + << "Total bytes requested: " << state->buffer_size_ + << " Bytes received (so far) in HTTP response body: " + << (state->bytes_transferred_ + curl_bytes_received); + } + + size_t bytes_to_copy = + std::min(curl_bytes_received, user_buffer_bytes_available); + memcpy(&state->buffer_[state->bytes_transferred_], ptr, bytes_to_copy); + state->bytes_transferred_ += bytes_to_copy; + return bytes_to_copy; +} + +size_t CurlHttpRequest::GetResultBufferDirectBytesTransferred() { + CHECK(direct_response_.buffer_ != nullptr); + return direct_response_.bytes_transferred_; +} + Status CurlHttpRequest::SetTimeouts(uint32 connection, uint32 inactivity, uint32 total) { TF_RETURN_IF_ERROR(CheckInitialized()); diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h index e4c91dac8d2..0686b692cb5 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.h +++ b/tensorflow/core/platform/cloud/curl_http_request.h @@ -103,6 +103,26 @@ class CurlHttpRequest : public HttpRequest { /// read. Existing content of the vector will be cleared. Status SetResultBuffer(std::vector* out_buffer) override; + /// \brief Specifies the buffer for receiving the response body, when the + /// caller knows the maximum size of the response body. + /// + /// This method allows the caller to receive the response body without an + /// additional intermediate buffer allocation and copy. This method should + /// be called before calling Send(). After Send() has succeeded, the caller + /// should use the GetResultBufferDirectBytesTransferred() method in order + /// to learn how many bytes were transferred. + /// + /// Using this method is mutually exclusive with using SetResultBuffer(). 
+ Status SetResultBufferDirect(char* buffer, size_t size) override; + + /// \brief Returns the number of bytes (of the response body) that were + /// transferred, when using the SetResultBufferDirect() method. The returned + /// value will always be less than or equal to the 'size' parameter that + /// was passed to SetResultBufferDirect(). If the actual HTTP response body + /// was greater than 'size' bytes, then this transfer method will only copy + /// the first 'size' bytes, and the rest will be ignored. + size_t GetResultBufferDirectBytesTransferred() override; + /// \brief Returns the response headers of a completed request. /// /// If the header is not found, returns an empty string. @@ -127,6 +147,10 @@ class CurlHttpRequest : public HttpRequest { /// A write callback in the form which can be accepted by libcurl. static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb, void* userdata); + + /// Processes response body content received when using SetResultBufferDirect. + static size_t WriteCallbackDirect(const void* ptr, size_t size, size_t nmemb, + void* userdata); /// A read callback in the form which can be accepted by libcurl. static size_t ReadCallback(void* ptr, size_t size, size_t nmemb, FILE* userdata); @@ -150,6 +174,14 @@ class CurlHttpRequest : public HttpRequest { size_t post_body_read_ = 0; std::vector* response_buffer_ = nullptr; + + struct DirectResponseState { + char* buffer_; + size_t buffer_size_; + size_t bytes_transferred_; + }; + DirectResponseState direct_response_ = {}; + CURL* curl_ = nullptr; curl_slist* curl_headers_ = nullptr; curl_slist* resolve_list_ = nullptr; diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc index 2d3e46edaf8..d108849c0fe 100644 --- a/tensorflow/core/platform/cloud/curl_http_request_test.cc +++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc @@ -288,6 +288,39 @@ TEST(CurlHttpRequestTest, GetRequest) { EXPECT_EQ(200, http_request.GetResponseCode()); } +TEST(CurlHttpRequestTest, GetRequest_Direct) { + FakeLibCurl libcurl("get response", 200); + CurlHttpRequest http_request(&libcurl); + TF_EXPECT_OK(http_request.Init()); + + std::vector scratch(100, 0); + + TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); + TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer")); + TF_EXPECT_OK(http_request.SetRange(100, 199)); + TF_EXPECT_OK( + http_request.SetResultBufferDirect(scratch.data(), scratch.capacity())); + TF_EXPECT_OK(http_request.Send()); + + string expected_response = "get response"; + size_t response_bytes_transferred = + http_request.GetResultBufferDirectBytesTransferred(); + EXPECT_EQ(response_bytes_transferred, expected_response.size()); + EXPECT_EQ( + "get response", + string(scratch.begin(), scratch.begin() + response_bytes_transferred)); + + // Check interactions with libcurl. 
+ EXPECT_TRUE(libcurl.is_initialized_); + EXPECT_EQ("http://www.testuri.com", libcurl.url_); + EXPECT_EQ("100-199", libcurl.range_); + EXPECT_EQ("", libcurl.custom_request_); + EXPECT_EQ(1, libcurl.headers_->size()); + EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl.headers_)[0]); + EXPECT_FALSE(libcurl.is_post_); + EXPECT_EQ(200, http_request.GetResponseCode()); +} + TEST(CurlHttpRequestTest, GetRequest_Empty) { FakeLibCurl libcurl("", 200); CurlHttpRequest http_request(&libcurl); diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc index e1afc7b308e..e6fa93890f9 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.cc +++ b/tensorflow/core/platform/cloud/file_block_cache.cc @@ -123,8 +123,12 @@ Status FileBlockCache::MaybeFetch(const Key& key, case FetchState::CREATED: block->state = FetchState::FETCHING; block->mu.unlock(); // Release the lock while making the API call. - status.Update( - block_fetcher_(key.first, key.second, block_size_, &block->data)); + block->data.clear(); + block->data.resize(block_size_, 0); + size_t bytes_transferred; + status.Update(block_fetcher_(key.first, key.second, block_size_, + block->data.data(), &bytes_transferred)); + block->data.resize(bytes_transferred, 0); block->mu.lock(); // Reacquire the lock immediately afterwards if (status.ok()) { downloaded_block = true; @@ -150,15 +154,15 @@ Status FileBlockCache::MaybeFetch(const Key& key, } Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, - std::vector* out) { - out->clear(); + char* buffer, size_t* bytes_transferred) { + *bytes_transferred = 0; if (n == 0) { return Status::OK(); } if (block_size_ == 0 || max_bytes_ == 0) { // The cache is effectively disabled, so we pass the read through to the // fetcher without breaking it up into blocks. - return block_fetcher_(filename, offset, n, out); + return block_fetcher_(filename, offset, n, buffer, bytes_transferred); } // Calculate the block-aligned start and end of the read. size_t start = block_size_ * (offset / block_size_); @@ -166,6 +170,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, if (finish < offset + n) { finish += block_size_; } + size_t total_bytes_transferred = 0; // Now iterate through the blocks, reading them one at a time. for (size_t pos = start; pos < finish; pos += block_size_) { Key key = std::make_pair(filename, pos); @@ -181,6 +186,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, // The requested offset is at or beyond the end of the file. This can // happen if `offset` is not block-aligned, and the read returns the last // block in the file, which does not extend all the way out to `offset`. + *bytes_transferred = total_bytes_transferred; return errors::OutOfRange("EOF at offset ", offset, " in file ", filename, " at position ", pos, "with data size ", data.size()); @@ -196,13 +202,16 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, end -= (pos + data.size()) - (offset + n); } if (begin < end) { - out->insert(out->end(), begin, end); + size_t bytes_to_copy = end - begin; + memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy); + total_bytes_transferred += bytes_to_copy; } if (data.size() < block_size_) { // The block was a partial block and thus signals EOF at its upper bound. 
break; } } + *bytes_transferred = total_bytes_transferred; return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h index 36dbf9db832..74e792a6251 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.h +++ b/tensorflow/core/platform/cloud/file_block_cache.h @@ -43,8 +43,9 @@ class FileBlockCache { /// cache is constructed. The returned Status should be OK as long as the /// read from the remote filesystem succeeded (similar to the semantics of the /// read(2) system call). - typedef std::function*)> + typedef std::function BlockFetcher; FileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness, @@ -83,8 +84,8 @@ class FileBlockCache { /// placed in `out`. /// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed /// in `out`). - Status Read(const string& filename, size_t offset, size_t n, - std::vector* out); + Status Read(const string& filename, size_t offset, size_t n, char* buffer, + size_t* bytes_transferred); /// Remove all cached blocks for `filename`. void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_); diff --git a/tensorflow/core/platform/cloud/file_block_cache_test.cc b/tensorflow/core/platform/cloud/file_block_cache_test.cc index bebed5af10d..ae87e0de29b 100644 --- a/tensorflow/core/platform/cloud/file_block_cache_test.cc +++ b/tensorflow/core/platform/cloud/file_block_cache_test.cc @@ -25,6 +25,18 @@ limitations under the License. namespace tensorflow { namespace { +Status ReadCache(FileBlockCache* cache, const string& filename, size_t offset, + size_t n, std::vector* out) { + out->clear(); + out->resize(n, 0); + size_t bytes_transferred = 0; + Status status = + cache->Read(filename, offset, n, out->data(), &bytes_transferred); + EXPECT_LE(bytes_transferred, n); + out->resize(bytes_transferred, n); + return status; +} + TEST(FileBlockCacheTest, PassThrough) { const string want_filename = "foo/bar"; const size_t want_offset = 42; @@ -32,12 +44,13 @@ TEST(FileBlockCacheTest, PassThrough) { int calls = 0; auto fetcher = [&calls, want_filename, want_offset, want_n]( const string& got_filename, size_t got_offset, - size_t got_n, std::vector* out) { + size_t got_n, char* buffer, size_t* bytes_transferred) { EXPECT_EQ(got_filename, want_filename); EXPECT_EQ(got_offset, want_offset); EXPECT_EQ(got_n, want_n); calls++; - out->resize(got_n, 'x'); + memset(buffer, 'x', got_n); + *bytes_transferred = got_n; return Status::OK(); }; // If block_size, max_bytes, or both are zero, the cache is a pass-through. @@ -45,11 +58,11 @@ TEST(FileBlockCacheTest, PassThrough) { FileBlockCache cache2(0, 1, 0, fetcher); FileBlockCache cache3(0, 0, 0, fetcher); std::vector out; - TF_EXPECT_OK(cache1.Read(want_filename, want_offset, want_n, &out)); + TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out)); EXPECT_EQ(calls, 1); - TF_EXPECT_OK(cache2.Read(want_filename, want_offset, want_n, &out)); + TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out)); EXPECT_EQ(calls, 2); - TF_EXPECT_OK(cache3.Read(want_filename, want_offset, want_n, &out)); + TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out)); EXPECT_EQ(calls, 3); } @@ -63,13 +76,13 @@ TEST(FileBlockCacheTest, BlockAlignment) { } // The fetcher just fetches slices of the buffer. 
auto fetcher = [&buf](const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { if (offset < buf.size()) { - if (offset + n > buf.size()) { - out->insert(out->end(), buf.begin() + offset, buf.end()); - } else { - out->insert(out->end(), buf.begin() + offset, buf.begin() + offset + n); - } + size_t bytes_to_copy = std::min(buf.size() - offset, n); + memcpy(buffer, buf.data() + offset, bytes_to_copy); + *bytes_transferred = bytes_to_copy; + } else { + *bytes_transferred = 0; } return Status::OK(); }; @@ -80,7 +93,7 @@ TEST(FileBlockCacheTest, BlockAlignment) { for (size_t offset = 0; offset < 10; offset++) { for (size_t n = block_size - 2; n <= block_size + 2; n++) { std::vector got; - TF_EXPECT_OK(cache.Read("", offset, n, &got)); + TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got)); // Verify the size of the read. if (offset + n <= size) { // Expect a full read. @@ -108,24 +121,27 @@ TEST(FileBlockCacheTest, CacheHits) { const size_t block_size = 16; std::set calls; auto fetcher = [&calls, block_size](const string& filename, size_t offset, - size_t n, std::vector* out) { + size_t n, char* buffer, + size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset; calls.insert(offset); - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; const uint32 block_count = 256; FileBlockCache cache(block_size, block_count * block_size, 0, fetcher); std::vector out; + out.resize(block_count, 0); // The cache has space for `block_count` blocks. The loop with i = 0 should // fill the cache, and the loop with i = 1 should be all cache hits. The // fetcher checks that it is called once and only once for each offset (to // fetch the corresponding block). for (int i = 0; i < 2; i++) { for (int j = 0; j < block_count; j++) { - TF_EXPECT_OK(cache.Read("", block_size * j, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out)); } } } @@ -138,36 +154,39 @@ TEST(FileBlockCacheTest, OutOfRange) { bool second_block = false; auto fetcher = [block_size, file_size, &first_block, &second_block]( const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); + size_t bytes_to_copy = 0; if (offset == 0) { // The first block (16 bytes) of the file. - out->resize(n, 'x'); + memset(buffer, 'x', n); + bytes_to_copy = n; first_block = true; } else if (offset == block_size) { // The second block (8 bytes) of the file. - out->resize(file_size - block_size, 'x'); + bytes_to_copy = file_size - block_size; + memset(buffer, 'x', bytes_to_copy); second_block = true; } + *bytes_transferred = bytes_to_copy; return Status::OK(); }; FileBlockCache cache(block_size, block_size, 0, fetcher); std::vector out; // Reading the first 16 bytes should be fine. - TF_EXPECT_OK(cache.Read("", 0, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out)); EXPECT_TRUE(first_block); EXPECT_EQ(out.size(), block_size); // Reading at offset file_size + 4 will read the second block (since the read // at file_size + 4 = 28 will be aligned to an offset of 16) but will return // OutOfRange because the offset is past the end of the 24-byte file. 
- Status status = cache.Read("", file_size + 4, 4, &out); + Status status = ReadCache(&cache, "", file_size + 4, 4, &out); EXPECT_EQ(status.code(), error::OUT_OF_RANGE); EXPECT_TRUE(second_block); - EXPECT_EQ(out.size(), 0); // Reading the second full block will return 8 bytes, from a cache hit. second_block = false; - TF_EXPECT_OK(cache.Read("", block_size, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out)); EXPECT_FALSE(second_block); EXPECT_EQ(out.size(), file_size - block_size); } @@ -178,20 +197,22 @@ TEST(FileBlockCacheTest, Inconsistent) { const size_t block_size = 16; // This fetcher returns OK but only fills in one byte for any offset. auto fetcher = [block_size](const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); - out->resize(1, 'x'); + EXPECT_GE(n, 1); + memset(buffer, 'x', 1); + *bytes_transferred = 1; return Status::OK(); }; FileBlockCache cache(block_size, 2 * block_size, 0, fetcher); std::vector out; // Read the second block; this should yield an OK status and a single byte. - TF_EXPECT_OK(cache.Read("", block_size, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out)); EXPECT_EQ(out.size(), 1); // Now read the first block; this should yield an INTERNAL error because we // had already cached a partial block at a later position. - Status status = cache.Read("", 0, block_size, &out); + Status status = ReadCache(&cache, "", 0, block_size, &out); EXPECT_EQ(status.code(), error::INTERNAL); } @@ -199,14 +220,16 @@ TEST(FileBlockCacheTest, LRU) { const size_t block_size = 16; std::list calls; auto fetcher = [&calls, block_size](const string& filename, size_t offset, - size_t n, std::vector* out) { + size_t n, char* buffer, + size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_FALSE(calls.empty()) << "at offset = " << offset; if (!calls.empty()) { EXPECT_EQ(offset, calls.front()); calls.pop_front(); } - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; const uint32 block_count = 2; @@ -216,38 +239,39 @@ TEST(FileBlockCacheTest, LRU) { // fetcher calls that the cache makes. calls.push_back(0); // Cache miss - drains an element from `calls`. - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); // Cache hit - does not drain an element from `calls`. - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); calls.push_back(block_size); // Cache miss followed by cache hit. - TF_EXPECT_OK(cache.Read("", block_size, 1, &out)); - TF_EXPECT_OK(cache.Read("", block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); calls.push_back(2 * block_size); // Cache miss followed by cache hit. Causes eviction of LRU element. - TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out)); - TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); // LRU element was at offset 0. Cache miss. calls.push_back(0); - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); // Element at 2 * block_size is still in cache, and this read should update // its position in the LRU list so it doesn't get evicted by the next read. 
- TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); // Element at block_size was evicted. Reading this element will also cause // the LRU element (at 0) to be evicted. calls.push_back(block_size); - TF_EXPECT_OK(cache.Read("", block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); // Element at 0 was evicted again. calls.push_back(0); - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); } TEST(FileBlockCacheTest, MaxStaleness) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { calls++; - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; std::vector out; @@ -256,14 +280,14 @@ TEST(FileBlockCacheTest, MaxStaleness) { // expected. FileBlockCache cache1(8, 16, 2 /* max staleness */, fetcher, env.get()); // Execute the first read to load the block. - TF_EXPECT_OK(cache1.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out)); EXPECT_EQ(calls, 1); // Now advance the clock one second at a time and redo the read. The call // count should advance every 3 seconds (i.e. every time the staleness is // greater than 2). for (int i = 1; i <= 10; i++) { env->SetNowSeconds(i + 1); - TF_EXPECT_OK(cache1.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out)); EXPECT_EQ(calls, 1 + i / 3); } // Now create a cache with max staleness of 0, and verify that it also works @@ -272,27 +296,27 @@ TEST(FileBlockCacheTest, MaxStaleness) { env->SetNowSeconds(0); FileBlockCache cache2(8, 16, 0 /* max staleness */, fetcher, env.get()); // Execute the first read to load the block. - TF_EXPECT_OK(cache2.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out)); EXPECT_EQ(calls, 1); // Advance the clock by a huge amount and verify that the cached block is // used to satisfy the read. env->SetNowSeconds(365 * 24 * 60 * 60); // ~1 year, just for fun. - TF_EXPECT_OK(cache2.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out)); EXPECT_EQ(calls, 1); } TEST(FileBlockCacheTest, RemoveFile) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { calls++; char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x'; if (offset > 0) { // The first block is lower case and all subsequent blocks are upper case. c = toupper(c); } - out->clear(); - out->resize(n, c); + memset(buffer, c, n); + *bytes_transferred = n; return Status::OK(); }; // This cache has space for 4 blocks; we'll read from two files. @@ -304,41 +328,41 @@ TEST(FileBlockCacheTest, RemoveFile) { std::vector A(n, 'A'); std::vector B(n, 'B'); // Fill the cache. - TF_EXPECT_OK(cache.Read("a", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); EXPECT_EQ(out, a); EXPECT_EQ(calls, 1); - TF_EXPECT_OK(cache.Read("a", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); EXPECT_EQ(out, A); EXPECT_EQ(calls, 2); - TF_EXPECT_OK(cache.Read("b", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); EXPECT_EQ(out, b); EXPECT_EQ(calls, 3); - TF_EXPECT_OK(cache.Read("b", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); EXPECT_EQ(out, B); EXPECT_EQ(calls, 4); // All four blocks should be in the cache now. 
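The staleness tests above drive the cache's clock through env->SetNowSeconds(); the fake Env they use is defined outside this excerpt. A minimal stand-in consistent with those call sites might be (class name and details are assumptions):

    // Hypothetical fake clock for exercising max_staleness: overriding
    // NowSeconds() lets the test decide when cached blocks become stale.
    class FakeEnv : public EnvWrapper {
     public:
      FakeEnv() : EnvWrapper(Env::Default()) {}
      uint64 NowSeconds() override { return now_; }
      void SetNowSeconds(uint64 now) { now_ = now; }

     private:
      uint64 now_ = 1;
    };
    // Used roughly as:
    //   std::unique_ptr<FakeEnv> env(new FakeEnv);
    //   FileBlockCache cache1(8, 16, 2 /* max staleness */, fetcher, env.get());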
- TF_EXPECT_OK(cache.Read("a", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); EXPECT_EQ(out, a); - TF_EXPECT_OK(cache.Read("a", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); EXPECT_EQ(out, A); - TF_EXPECT_OK(cache.Read("b", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); EXPECT_EQ(out, b); - TF_EXPECT_OK(cache.Read("b", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); EXPECT_EQ(out, B); EXPECT_EQ(calls, 4); // Remove the blocks from "a". cache.RemoveFile("a"); // Both blocks from "b" should still be there. - TF_EXPECT_OK(cache.Read("b", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); EXPECT_EQ(out, b); - TF_EXPECT_OK(cache.Read("b", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); EXPECT_EQ(out, B); EXPECT_EQ(calls, 4); // The blocks from "a" should not be there. - TF_EXPECT_OK(cache.Read("a", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); EXPECT_EQ(out, a); EXPECT_EQ(calls, 5); - TF_EXPECT_OK(cache.Read("a", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); EXPECT_EQ(out, A); EXPECT_EQ(calls, 6); } @@ -346,10 +370,10 @@ TEST(FileBlockCacheTest, RemoveFile) { TEST(FileBlockCacheTest, Prune) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { calls++; - out->clear(); - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; std::vector out; @@ -360,20 +384,20 @@ TEST(FileBlockCacheTest, Prune) { FileBlockCache cache(8, 32, 1 /* max staleness */, fetcher, env.get()); // Read three blocks into the cache, and advance the timestamp by one second // with each read. Start with a block of "a" at the current timestamp `now`. - TF_EXPECT_OK(cache.Read("a", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out)); // Now load a block of a different file "b" at timestamp `now` + 1 env->SetNowSeconds(now + 1); - TF_EXPECT_OK(cache.Read("b", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); // Now load a different block of file "a" at timestamp `now` + 1. When the // first block of "a" expires, this block should also be removed because it // also belongs to file "a". - TF_EXPECT_OK(cache.Read("a", 8, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out)); // Ensure that all blocks are in the cache (i.e. reads are cache hits). EXPECT_EQ(cache.CacheSize(), 24); EXPECT_EQ(calls, 3); - TF_EXPECT_OK(cache.Read("a", 0, 1, &out)); - TF_EXPECT_OK(cache.Read("b", 0, 1, &out)); - TF_EXPECT_OK(cache.Read("a", 8, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out)); EXPECT_EQ(calls, 3); // Advance the fake timestamp so that "a" becomes stale via its first block. env->SetNowSeconds(now + 2); @@ -389,7 +413,7 @@ TEST(FileBlockCacheTest, Prune) { // There should be one block left in the cache, and it should be the first // block of "b". EXPECT_EQ(cache.CacheSize(), 8); - TF_EXPECT_OK(cache.Read("b", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); EXPECT_EQ(calls, 3); // Advance the fake time to `now` + 3, at which point "b" becomes stale. 
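All of the fetcher lambdas in these tests now share the same shape; the callback type itself is declared in file_block_cache.h, which is not part of this excerpt. Reconstructed from the call sites, the updated signature is presumably along these lines (sketch, not verbatim):

    // Presumed block-fetcher contract after this change: fill `buffer` with at
    // most `buffer_size` bytes of the file starting at `offset`, report via
    // `*bytes_transferred` how many bytes were actually written (possibly fewer
    // than requested near EOF, or 0 past it), and return OK unless a real error
    // occurred.
    typedef std::function<Status(const string& filename, size_t offset,
                                 size_t buffer_size, char* buffer,
                                 size_t* bytes_transferred)>
        BlockFetcher;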
env->SetNowSeconds(now + 3); @@ -409,14 +433,14 @@ TEST(FileBlockCacheTest, ParallelReads) { const int callers = 4; BlockingCounter counter(callers); auto fetcher = [&counter](const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { counter.DecrementCount(); if (!counter.WaitFor(std::chrono::seconds(10))) { // This avoids having the test time out, which is harder to debug. return errors::FailedPrecondition("desired concurrency not reached"); } - out->clear(); - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; const int block_size = 8; @@ -426,7 +450,8 @@ TEST(FileBlockCacheTest, ParallelReads) { threads.emplace_back( Env::Default()->StartThread({}, "caller", [&cache, i, block_size]() { std::vector out; - TF_EXPECT_OK(cache.Read("a", i * block_size, block_size, &out)); + TF_EXPECT_OK( + ReadCache(&cache, "a", i * block_size, block_size, &out)); std::vector x(block_size, 'x'); EXPECT_EQ(out, x); })); @@ -443,11 +468,12 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) { Notification notification; auto fetcher = [&num_requests, ¬ification, block_size]( const string& filename, size_t offset, size_t n, - std::vector* out) { + char* buffer, size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset, 0); num_requests++; - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; notification.Notify(); // Wait for other thread to issue read. Env::Default()->SleepForMicroseconds(100000); // 0.1 secs @@ -458,17 +484,16 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) { std::unique_ptr concurrent( Env::Default()->StartThread({}, "concurrent", [&cache, block_size] { std::vector out; - TF_EXPECT_OK(cache.Read("", 0, block_size / 2, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out)); EXPECT_EQ(out.size(), block_size / 2); })); EXPECT_TRUE(WaitForNotificationWithTimeout(¬ification, 10000)) << "Timeout waiting for concurrent thread to start."; std::vector out; - TF_EXPECT_OK(cache.Read("", block_size / 2, block_size / 2, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out)); EXPECT_EQ(out.size(), block_size / 2); EXPECT_EQ(1, num_requests); } - } // namespace } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc index 2c3819f1e2e..c96d364228e 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc +++ b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc @@ -58,6 +58,10 @@ class TestHttpRequest : public HttpRequest { Status SetResultBuffer(std::vector* out_buffer) override { return Status::OK(); } + Status SetResultBufferDirect(char* buffer, size_t size) override { + return Status::OK(); + } + size_t GetResultBufferDirectBytesTransferred() override { return 0; } string GetResponseHeader(const string& name) const override { return ""; } uint64 GetResponseCode() const override { return 0; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index a183fe6fa80..ec66ab01d6c 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -285,11 +285,11 @@ class GcsRandomAccessFile : public RandomAccessFile { Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { *result = StringPiece(); - std::vector out; - TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, 
offset, n, &out)); - std::memcpy(scratch, out.data(), std::min(out.size(), n)); - *result = StringPiece(scratch, std::min(out.size(), n)); - if (result->size() < n) { + size_t bytes_transferred; + TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch, + &bytes_transferred)); + *result = StringPiece(scratch, bytes_transferred); + if (bytes_transferred < n) { // This is not an error per se. The RandomAccessFile interface expects // that Read returns OutOfRange if fewer bytes were read than requested. return errors::OutOfRange("EOF reached, ", result->size(), @@ -721,15 +721,17 @@ std::unique_ptr GcsFileSystem::MakeFileBlockCache( std::unique_ptr file_block_cache( new FileBlockCache(block_size, max_bytes, max_staleness, [this](const string& filename, size_t offset, size_t n, - std::vector* out) { - return LoadBufferFromGCS(filename, offset, n, out); + char* buffer, size_t* bytes_transferred) { + return LoadBufferFromGCS(filename, offset, n, buffer, + bytes_transferred); })); return file_block_cache; } // A helper function to actually read the data from GCS. Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, - size_t n, std::vector* out) { + size_t n, char* buffer, + size_t* bytes_transferred) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object)); @@ -739,21 +741,23 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket, "/", request->EscapeString(object)))); TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(out)); + TF_RETURN_IF_ERROR(request->SetResultBufferDirect(buffer, n)); TF_RETURN_IF_ERROR( request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read)); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://", bucket, "/", object); + size_t bytes_read = request->GetResultBufferDirectBytesTransferred(); + *bytes_transferred = bytes_read; VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ " - << offset << " of size: " << out->size(); + << offset << " of size: " << bytes_read; - if (out->size() < block_size()) { + if (bytes_read < block_size()) { // Check stat cache to see if we encountered an interrupted read. FileStatistics stat; if (stat_cache_->Lookup(filename, &stat)) { - if (offset + out->size() < stat.length) { + if (offset + bytes_read < stat.length) { return errors::Internal(strings::Printf( "File contents are inconsistent for file: %s @ %lu.", filename.c_str(), offset)); diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index f4190b3f1ee..731f97a4aad 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -177,7 +177,7 @@ class GcsFileSystem : public FileSystem { /// Loads file contents from GCS for a given filename, offset, and length. Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n, - std::vector* out); + char* buffer, size_t* bytes_transferred); std::unique_ptr auth_provider_; std::unique_ptr http_request_factory_; diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h index 95a436c6229..6b13ac475eb 100644 --- a/tensorflow/core/platform/cloud/http_request.h +++ b/tensorflow/core/platform/cloud/http_request.h @@ -101,6 +101,20 @@ class HttpRequest { /// read. 
Existing content of the vector will be cleared. virtual Status SetResultBuffer(std::vector* out_buffer) = 0; + /// \brief Specifies the buffer for receiving the response body. + /// + /// This method should be used when a caller knows the upper bound of the + /// size of the response data. The caller provides a pre-allocated buffer + /// and its size. After the Send() method is called, the + /// GetResultBufferDirectBytesTransferred() method may be used to learn to the + /// number of bytes that were transferred using this method. + virtual Status SetResultBufferDirect(char* buffer, size_t size) = 0; + + /// \brief Returns the number of bytes transferred, when using + /// SetResultBufferDirect(). This method may only be used when using + /// SetResultBufferDirect(). + virtual size_t GetResultBufferDirectBytesTransferred() = 0; + /// \brief Returns the response headers of a completed request. /// /// If the header is not found, returns an empty string. diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h index f65c15dac77..72292dca53b 100644 --- a/tensorflow/core/platform/cloud/http_request_fake.h +++ b/tensorflow/core/platform/cloud/http_request_fake.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_ #define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_ +#include #include #include #include @@ -130,12 +131,25 @@ class FakeHttpRequest : public CurlHttpRequest { buffer_ = buffer; return Status::OK(); } + Status SetResultBufferDirect(char* buffer, size_t size) override { + direct_result_buffer_ = buffer; + direct_result_buffer_size_ = size; + return Status::OK(); + } + size_t GetResultBufferDirectBytesTransferred() override { + return direct_result_bytes_transferred_; + } Status Send() override { EXPECT_EQ(expected_request_, actual_request()) << "Unexpected HTTP request."; if (buffer_) { - buffer_->insert(buffer_->begin(), response_.c_str(), - response_.c_str() + response_.size()); + buffer_->insert(buffer_->begin(), response_.data(), + response_.data() + response_.size()); + } else if (direct_result_buffer_ != nullptr) { + size_t bytes_to_copy = + std::min(direct_result_buffer_size_, response_.size()); + memcpy(direct_result_buffer_, response_.data(), bytes_to_copy); + direct_result_bytes_transferred_ += bytes_to_copy; } return response_status_; } @@ -178,6 +192,9 @@ class FakeHttpRequest : public CurlHttpRequest { } std::vector* buffer_ = nullptr; + char* direct_result_buffer_ = nullptr; + size_t direct_result_buffer_size_ = 0; + size_t direct_result_bytes_transferred_ = 0; string expected_request_; string actual_uri_; string actual_request_; diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc index ba3c4e70907..8097624e09f 100644 --- a/tensorflow/core/platform/posix/env.cc +++ b/tensorflow/core/platform/posix/env.cc @@ -136,15 +136,19 @@ void Env::GetLocalTempDirectories(std::vector* list) { // Directories, in order of preference. 
If we find a dir that // exists, we stop adding other less-preferred dirs const char* candidates[] = { - // Non-null only during unittest/regtest - getenv("TEST_TMPDIR"), + // Non-null only during unittest/regtest + getenv("TEST_TMPDIR"), - // Explicitly-supplied temp dirs - getenv("TMPDIR"), - getenv("TMP"), + // Explicitly-supplied temp dirs + getenv("TMPDIR"), + getenv("TMP"), - // If all else fails - "/tmp", +#if defined(__ANDROID__) + "/data/local/tmp", +#endif + + // If all else fails + "/tmp", }; for (const char* d : candidates) { diff --git a/tensorflow/go/genop/api_def_map.go b/tensorflow/go/genop/api_def_map.go new file mode 100644 index 00000000000..07b689dbba2 --- /dev/null +++ b/tensorflow/go/genop/api_def_map.go @@ -0,0 +1,127 @@ +/* +Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +/* +#include +#include + +#include "tensorflow/c/c_api.h" +*/ +import "C" + +import ( + "errors" + "fmt" + "runtime" + "unsafe" + + "github.com/golang/protobuf/proto" + pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework" +) + +// Encapsulates a collection of API definitions. +// +// apiDefMap represents a map from operation name to corresponding +// ApiDef proto (see +// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto +// for ApiDef proto definition). +type apiDefMap struct { + c *C.TF_ApiDefMap +} + +// Creates and returns a new apiDefMap instance. +// +// oplist is and OpList proto instance (see +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto +// for OpList proto definition). + +func newAPIDefMap(oplist *pb.OpList) (*apiDefMap, error) { + // Create a buffer containing the serialized OpList. + opdefSerialized, err := proto.Marshal(oplist) + if err != nil { + return nil, fmt.Errorf("could not serialize OpDef for %s", oplist.String()) + } + data := C.CBytes(opdefSerialized) + defer C.free(data) + + opbuf := C.TF_NewBuffer() + defer C.TF_DeleteBuffer(opbuf) + opbuf.data = data + opbuf.length = C.size_t(len(opdefSerialized)) + + // Create ApiDefMap. + status := C.TF_NewStatus() + defer C.TF_DeleteStatus(status) + capimap := C.TF_NewApiDefMap(opbuf, status) + if C.TF_GetCode(status) != C.TF_OK { + return nil, errors.New(C.GoString(C.TF_Message(status))) + } + apimap := &apiDefMap{capimap} + runtime.SetFinalizer( + apimap, + func(a *apiDefMap) { + C.TF_DeleteApiDefMap(a.c) + }) + return apimap, nil +} + +// Updates apiDefMap with the overrides specified in `data`. +// +// data - ApiDef text proto. +func (m *apiDefMap) Put(data string) error { + cdata := C.CString(data) + defer C.free(unsafe.Pointer(cdata)) + status := C.TF_NewStatus() + defer C.TF_DeleteStatus(status) + C.TF_ApiDefMapPut(m.c, cdata, C.size_t(len(data)), status) + if C.TF_GetCode(status) != C.TF_OK { + return errors.New(C.GoString(C.TF_Message(status))) + } + return nil +} + +// Returns ApiDef proto instance for the TensorFlow operation +// named `opname`. 
+func (m *apiDefMap) Get(opname string) (*pb.ApiDef, error) { + cname := C.CString(opname) + defer C.free(unsafe.Pointer(cname)) + status := C.TF_NewStatus() + defer C.TF_DeleteStatus(status) + apidefBuf := C.TF_ApiDefMapGet( + m.c, cname, C.size_t(len(opname)), status) + defer C.TF_DeleteBuffer(apidefBuf) + if C.TF_GetCode(status) != C.TF_OK { + return nil, errors.New(C.GoString(C.TF_Message(status))) + } + if apidefBuf == nil { + return nil, fmt.Errorf("could not find ApiDef for %s", opname) + } + + var ( + apidef = new(pb.ApiDef) + size = int(apidefBuf.length) + // A []byte backed by C memory. + // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + data = (*[1 << 30]byte)(unsafe.Pointer(apidefBuf.data))[:size:size] + err = proto.Unmarshal(data, apidef) + ) + if err != nil { + return nil, err + } + return apidef, nil +} diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go index dec08dee1ca..82f7510f2ed 100644 --- a/tensorflow/go/genop/internal/genop.go +++ b/tensorflow/go/genop/internal/genop.go @@ -29,12 +29,18 @@ limitations under the License. // encountered. package internal -// #include "tensorflow/c/c_api.h" +/* +#include + +#include "tensorflow/c/c_api.h" +*/ import "C" import ( "fmt" "io" + "io/ioutil" + "path" "reflect" "strings" "text/template" @@ -47,15 +53,23 @@ import ( // GenerateFunctionsForRegisteredOps writes a Go source code file to w // containing functions for each TensorFlow operation registered in the address // space of the calling process. -func GenerateFunctionsForRegisteredOps(w io.Writer) error { - ops, err := registeredOps() +// apidefDirs should be a contain of directories containing api_def_*.pbtxt +// files to load. +func GenerateFunctionsForRegisteredOps( + w io.Writer, apidefDirs []string) error { + ops, apimap, err := registeredOps() if err != nil { return err } - return generateFunctionsForOps(w, ops) + for _, dir := range apidefDirs { + if err = updateAPIDefs(apimap, dir); err != nil { + return err + } + } + return generateFunctionsForOps(w, ops, apimap) } -func registeredOps() (*pb.OpList, error) { +func registeredOps() (*pb.OpList, *apiDefMap, error) { buf := C.TF_GetAllOpList() defer C.TF_DeleteBuffer(buf) var ( @@ -66,10 +80,31 @@ func registeredOps() (*pb.OpList, error) { data = (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:size:size] err = proto.Unmarshal(data, list) ) - return list, err + if err != nil { + return nil, nil, err + } + apimap, err := newAPIDefMap(list) + return list, apimap, err } -func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error { +func updateAPIDefs(m *apiDefMap, dir string) error { + files, err := ioutil.ReadDir(dir) + if err != nil { + return err + } + for _, file := range files { + data, err := ioutil.ReadFile(path.Join(dir, file.Name())) + if err != nil { + return fmt.Errorf("failed to read %q: %v", file.Name(), err) + } + if err = m.Put(string(data)); err != nil { + return fmt.Errorf("failed to process %q: %v", file.Name(), err) + } + } + return nil +} + +func generateFunctionsForOps(w io.Writer, ops *pb.OpList, apimap *apiDefMap) error { thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath() if err := tmplHeader.Execute(w, thisPackage); err != nil { return err @@ -83,14 +118,18 @@ func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error { if blacklist[op.Name] { continue } - if err := generateFunctionForOp(w, op); err != nil { + apidef, err := apimap.Get(op.Name) + if err != nil { + return err + } + if err := generateFunctionForOp(w, op, 
apidef); err != nil { return err } } return nil } -func generateFunctionForOp(w io.Writer, op *pb.OpDef) error { +func generateFunctionForOp(w io.Writer, op *pb.OpDef, apidef *pb.ApiDef) error { if strings.HasPrefix(op.Name, "_") { // Internal operation return nil } @@ -112,12 +151,16 @@ func generateFunctionForOp(w io.Writer, op *pb.OpDef) error { return nil } } - if op.Summary == "" { + if apidef.Summary == "" { // Undocumented operation, perhaps a sign of not being ready to // export. return nil } - return tmplOp.Execute(w, newTmplArgs(op)) + tmplArgs, err := newTmplArgs(op, apidef) + if err != nil { + return err + } + return tmplOp.Execute(w, tmplArgs) } var ( @@ -172,7 +215,7 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in type {{.Op.Name}}Attr func(optionalAttr) {{range .OptionalAttrs}} -// {{$.Op.Name}}{{CamelCase .Name}} sets the optional {{.Name}} attribute to value. +// {{$.Op.Name}}{{CamelCase .RenameTo}} sets the optional {{.RenameTo}} attribute to value. {{- if .Description}} // // value: {{MakeComment .Description}} @@ -180,9 +223,9 @@ type {{.Op.Name}}Attr func(optionalAttr) // If not specified, defaults to {{StripLeadingColon .DefaultValue}} {{- if .HasMinimum}} // -// {{if IsListAttr .}}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}} +// {{if .IsListAttr }}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}} {{- end}} -func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr { +func {{$.Op.Name}}{{CamelCase .RenameTo}}(value {{GoType .Type}}) {{$.Op.Name}}Attr { return func(m optionalAttr) { m[{{printf "%q" .Name}}] = value } @@ -192,14 +235,14 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {{- /* Create a godoc friendly comment. 
*/ -}} -// {{MakeComment .Op.Summary}} +// {{MakeComment .APIDef.Summary}} {{- with .Op.Deprecation}} // // DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}} {{- end -}} -{{- with .Op.Description}} +{{- with .APIDef.Description}} // // {{MakeComment .}} {{- end -}} @@ -207,11 +250,11 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {{- if .DescribeArguments}} // // Arguments: -{{- range .Op.InputArg}} -// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}} +{{- range .InArgsReordered}} +// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} {{- end -}} {{- range .RequiredAttrs}} -// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}} +// {{if .Description}}{{Identifier .RenameTo}}: {{MakeComment .Description}}{{end}} {{- end -}} {{- end -}} @@ -221,12 +264,12 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {{- else }} {{- if .DescribeOutputs}} // -{{- if ((len .Op.OutputArg) eq 1) }} -// Returns {{range .Op.OutputArg}}{{MakeComment .Description}}{{end}} +{{- if ((len .OutArgs) eq 1) }} +// Returns {{range .OutArgs}}{{MakeComment .Description}}{{end}} {{- else }} // Returns: -{{- range .Op.OutputArg}} -// {{Identifier .Name}}{{if .Description}}: {{MakeComment .Description}}{{end}} +{{- range .OutArgs}} +// {{Identifier .RenameTo}}{{if .Description}}: {{MakeComment .Description}}{{end}} {{- end -}} {{- end -}} {{- end -}} @@ -247,15 +290,15 @@ func {{.Op.Name}} */ -}} (scope *Scope -{{- range $i, $a := .Op.InputArg}}, {{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}} -{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.Name}} {{GoType $a.Type}}{{end -}} +{{- range $i, $a := .InArgsReordered}}, {{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}} +{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.RenameTo}} {{GoType $a.Type}}{{end -}} {{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} ) -{{- /* Construct outputs: len(OpDef.OutputArg) or a *tf.Operation */ -}} +{{- /* Construct outputs: len(.OutArgs) or a *tf.Operation */ -}} -{{if .Op.OutputArg -}} -({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}) +{{if .OutArgs -}} +({{range $i,$a := .OutArgs}}{{if $i}}, {{end}}{{Identifier $a.RenameTo}} {{if $a.IsListArg}}[]{{end}}tf.Output{{end -}}) {{- else -}} (o *tf.Operation) {{- end }} { @@ -263,7 +306,7 @@ func {{.Op.Name}} return } {{if .HasAttrs -}} - attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .Name}},{{end}}} + attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .RenameTo}},{{end}}} {{if .OptionalAttrs -}} for _, a := range optional { a(attrs) @@ -272,16 +315,16 @@ func {{.Op.Name}} {{end -}} opspec := tf.OpSpec{ Type: {{printf "%q" .Op.Name}}, - {{if .Op.InputArg -}} + {{if .InArgs -}} Input: []tf.Input{ - {{range .Op.InputArg}}{{if IsListArg .}}tf.OutputList({{Identifier .Name}}){{else}}{{Identifier .Name}}{{end}}, {{end}} + {{range $i,$a := .InArgs}}{{if $a.IsListArg}}tf.OutputList({{Identifier $a.RenameTo}}){{else}}{{Identifier $a.RenameTo}}{{end}}, {{end}} }, {{- end}} {{- if .HasAttrs}} Attrs: attrs, {{- end}} } - {{- if .Op.OutputArg}} + {{- if .OutArgs}} {{- if .HasListOutput}} op := scope.AddOperation(opspec) if scope.Err() != nil { @@ -289,43 +332,105 @@ func {{.Op.Name}} } var idx int var err error - {{- range 
$i, $a := .Op.OutputArg}} - {{- if IsListArg $a}} - if {{Identifier .Name}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { + {{- range $i, $a := .OutArgs}} + {{- if $a.IsListArg}} + if {{Identifier .RenameTo}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { scope.UpdateErr({{printf "%q" $.Op.Name}}, err) return } {{- else }} - {{Identifier .Name}} = op.Output(idx) + {{Identifier .RenameTo}} = op.Output(idx) {{- end }}{{- /* if IsListArg */}} - {{- end }}{{- /* range .Op.OutputArg */}} - return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier .Name}}{{end}} + {{- end }}{{- /* range .OutArgs */}} + return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}{{Identifier .RenameTo}}{{end}} {{- else }} op := scope.AddOperation(opspec) - return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} + return {{range $i, $a := .OutArgs}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} {{- end }}{{- /* if .HasListOutput */}} {{- else }} return scope.AddOperation(opspec) - {{- end }}{{- /* if .Op.OutputArg */}} + {{- end }}{{- /* if .OutArgs */}} } `)) ) +type attrWrapper struct { + op *pb.OpDef_AttrDef + api *pb.ApiDef_Attr +} + +func (a *attrWrapper) Name() string { return a.api.Name } +func (a *attrWrapper) RenameTo() string { return a.api.RenameTo } +func (a *attrWrapper) Description() string { return a.api.Description } +func (a *attrWrapper) Type() string { return a.op.Type } +func (a *attrWrapper) IsListAttr() bool { return isListAttr(a.op) } +func (a *attrWrapper) HasMinimum() bool { return a.op.HasMinimum } +func (a *attrWrapper) Minimum() int64 { return a.op.Minimum } +func (a *attrWrapper) DefaultValue() interface{} { return a.api.DefaultValue } + +type argWrapper struct { + op *pb.OpDef_ArgDef + api *pb.ApiDef_Arg +} + +func (a *argWrapper) Name() string { return a.api.Name } +func (a *argWrapper) RenameTo() string { return a.api.RenameTo } +func (a *argWrapper) Description() string { return a.api.Description } +func (a *argWrapper) IsListArg() bool { return isListArg(a.op) } + type tmplArgs struct { - Op *pb.OpDef + Op *pb.OpDef + APIDef *pb.ApiDef // Op.Attr is split into two categories // (1) Required: These must be specified by the client and are thus // included in the function signature. // (2) Optional: These need not be specified (as they have default // values) and thus do not appear in the function signature. - RequiredAttrs []*pb.OpDef_AttrDef - OptionalAttrs []*pb.OpDef_AttrDef + RequiredAttrs []*attrWrapper + OptionalAttrs []*attrWrapper + InArgs []*argWrapper + // Input arguments ordered based on arg_order field of ApiDef. 
+ InArgsReordered []*argWrapper + OutArgs []*argWrapper } -func newTmplArgs(op *pb.OpDef) *tmplArgs { - ret := tmplArgs{Op: op} +func newTmplArgs(op *pb.OpDef, apidef *pb.ApiDef) (*tmplArgs, error) { + ret := tmplArgs{Op: op, APIDef: apidef} + + // Setup InArgs field + for i, in := range op.InputArg { + argCombined := argWrapper{op: in, api: apidef.InArg[i]} + ret.InArgs = append(ret.InArgs, &argCombined) + } + + // Setup OutArgs field + for i, out := range op.OutputArg { + argCombined := argWrapper{op: out, api: apidef.OutArg[i]} + ret.OutArgs = append(ret.OutArgs, &argCombined) + } + + // Setup InArgsReordered field + for _, argName := range apidef.ArgOrder { + // Find the argument in op.InputArg + argIndex := -1 + for i, in := range op.InputArg { + if in.Name == argName { + argIndex = i + break + } + } + if argIndex == -1 { + return nil, fmt.Errorf( + "couldn't find argument %s in ApiDef for op %s", + argName, op.Name) + } + argCombined := argWrapper{ + op: op.InputArg[argIndex], api: apidef.InArg[argIndex]} + ret.InArgsReordered = append(ret.InArgsReordered, &argCombined) + } + if len(op.Attr) == 0 { - return &ret + return &ret, nil } // Attributes related to the InputArg's type are inferred automatically // and are not exposed to the client. @@ -341,28 +446,29 @@ func newTmplArgs(op *pb.OpDef) *tmplArgs { inferred[in.NumberAttr] = true } } - for _, attr := range op.Attr { + for i, attr := range op.Attr { if inferred[attr.Name] { continue } + attrCombined := attrWrapper{op: attr, api: apidef.Attr[i]} if attr.DefaultValue == nil { - ret.RequiredAttrs = append(ret.RequiredAttrs, attr) + ret.RequiredAttrs = append(ret.RequiredAttrs, &attrCombined) } else { - ret.OptionalAttrs = append(ret.OptionalAttrs, attr) + ret.OptionalAttrs = append(ret.OptionalAttrs, &attrCombined) } } - return &ret + return &ret, nil } func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 } func (a *tmplArgs) DescribeArguments() bool { - for _, arg := range a.Op.InputArg { - if arg.Description != "" { + for _, arg := range a.InArgs { + if arg.Description() != "" { return true } } for _, attr := range a.RequiredAttrs { - if attr.Description != "" { + if attr.Description() != "" { return true } } @@ -370,16 +476,16 @@ func (a *tmplArgs) DescribeArguments() bool { } func (a *tmplArgs) DescribeOutputs() bool { - for _, arg := range a.Op.OutputArg { - if arg.Description != "" { + for _, arg := range a.OutArgs { + if arg.Description() != "" { return true } } return false } func (a *tmplArgs) HasListOutput() bool { - for _, arg := range a.Op.OutputArg { - if isListArg(arg) { + for _, arg := range a.OutArgs { + if arg.IsListArg() { return true } } diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go index c984c0063a9..b3a23dff102 100644 --- a/tensorflow/go/genop/internal/genop_test.go +++ b/tensorflow/go/genop/internal/genop_test.go @@ -25,19 +25,44 @@ import ( pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework" ) +// Creates an ApiDef based on opdef and applies overrides +// from apidefText (ApiDef text proto). 
+func GetAPIDef(t *testing.T, opdef *pb.OpDef, apidefText string) *pb.ApiDef { + opdefList := &pb.OpList{Op: []*pb.OpDef{opdef}} + apimap, err := newAPIDefMap(opdefList) + if err != nil { + t.Fatal(err) + } + err = apimap.Put(apidefText) + if err != nil { + t.Fatal(err) + } + apidef, err := apimap.Get(opdef.Name) + if err != nil { + t.Fatal(err) + } + return apidef +} + func TestGenerateOp(t *testing.T) { // TestGenerateOp validates the generated source code for an op. // The OpDef for the test cases are simplified forms of real ops. testdata := []struct { tag string opdef string + apidef string wanted string }{ { tag: "NoOp", opdef: ` name: "NoOp" +`, + apidef: ` +op: < +graph_op_name: "NoOp" summary: "No. Op." +> `, wanted: ` // No. Op. @@ -80,8 +105,13 @@ attr: < > > > +`, + apidef: ` +op: < +graph_op_name: "Add" summary: "Returns x + y element-wise." description: "Blah blah", +> `, wanted: ` // Returns x + y element-wise. @@ -122,7 +152,12 @@ attr: < name: "DstT" type: "type" > +`, + apidef: ` +op: < +graph_op_name: "Cast" summary: "Cast x of type SrcT to y of DstT." +> `, wanted: ` // Cast x of type SrcT to y of DstT. @@ -149,12 +184,10 @@ func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) { name: "DecodeJpeg" input_arg: < name: "contents" - description: "0-D. The JPEG-encoded image." type: DT_STRING > output_arg: < name: "image" - description: "3-D with shape [height, width, channels]" type: DT_UINT8 > attr: < @@ -163,7 +196,6 @@ attr: < default_value: < i: 0 > - description: "Number of color channels for the decoded image." > attr: < name: "fancy_upscaling" @@ -171,7 +203,6 @@ attr: < default_value: < b: true > - description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)." > attr: < name: "acceptable_fraction" @@ -179,10 +210,34 @@ attr: < default_value: < f: 1 > +> +`, + apidef: ` +op: < +graph_op_name: "DecodeJpeg" +in_arg: < + name: "contents" + description: "0-D. The JPEG-encoded image." +> +out_arg: < + name: "image" + description: "3-D with shape [height, width, channels]" +> +attr: < + name: "channels" + description: "Number of color channels for the decoded image." +> +attr: < + name: "fancy_upscaling" + description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)." +> +attr: < + name: "acceptable_fraction" description: "The minimum required fraction of lines before a truncated\ninput is accepted." > summary: "Decode a JPEG-encoded image to a uint8 tensor." description: "Norna dorna fjord\nkajorna\nhahaha" +> `, wanted: ` // DecodeJpegAttr is an optional argument to DecodeJpeg. @@ -270,7 +325,12 @@ attr: < name: "T" type: "type" > +`, + apidef: ` +op: < +graph_op_name: "TwoOutputs" summary: "Op that produces multiple outputs" +> `, wanted: ` // Op that produces multiple outputs @@ -326,8 +386,13 @@ attr: < > > > +`, + apidef: ` +op: < +graph_op_name: "ShapeN" summary: "Returns shape of tensors." description: "Some description here." +> `, wanted: ` // ShapeNAttr is an optional argument to ShapeN. 
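For orientation, the new Go helpers compose as follows; newAPIDefMap, Put and Get are the functions added in api_def_map.go, while the driver function and its name here are only an illustrative sketch (it assumes the genop/internal package context for `pb`):

    // Illustrative only: build an ApiDef map for a registered op list, merge a
    // text-proto override, then look up the (possibly renamed) API for one op.
    func lookupAPIDefExample(oplist *pb.OpList, overrideTextProto, opName string) (*pb.ApiDef, error) {
    	m, err := newAPIDefMap(oplist) // wraps TF_NewApiDefMap
    	if err != nil {
    		return nil, err
    	}
    	if err := m.Put(overrideTextProto); err != nil { // TF_ApiDefMapPut
    		return nil, err
    	}
    	return m.Get(opName) // TF_ApiDefMapGet, unmarshaled into pb.ApiDef
    }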
@@ -371,6 +436,102 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t } return output } +`, + }, + { + tag: "ApiDefOverrides", + opdef: ` +name: "TestOp" +input_arg: < + name: "a" + type: DT_STRING +> +input_arg: < + name: "b" + type: DT_STRING +> +output_arg: < + name: "c" + type: DT_UINT8 +> +attr: < + name: "d" + type: "int" + default_value: < + i: 0 + > +> +`, + apidef: ` +op: < +graph_op_name: "TestOp" +in_arg: < + name: "a" + rename_to: "aa" + description: "Description for aa." +> +in_arg: < + name: "b" + rename_to: "bb" + description: "Description for bb." +> +arg_order: "b" +arg_order: "a" +out_arg: < + name: "c" + rename_to: "cc" + description: "Description for cc." +> +attr: < + name: "d" + rename_to: "dd" + description: "Description for dd." +> +summary: "Summary for TestOp." +description: "Description for TestOp." +> +`, + wanted: ` +// TestOpAttr is an optional argument to TestOp. +type TestOpAttr func(optionalAttr) + +// TestOpDd sets the optional dd attribute to value. +// +// value: Description for dd. +// If not specified, defaults to 0 +func TestOpDd(value int64) TestOpAttr { + return func(m optionalAttr) { + m["d"] = value + } +} + +// Summary for TestOp. +// +// Description for TestOp. +// +// Arguments: +// bb: Description for bb. +// aa: Description for aa. +// +// Returns Description for cc. +func TestOp(scope *Scope, bb tf.Output, aa tf.Output, optional ...TestOpAttr) (cc tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TestOp", + Input: []tf.Input{ + aa, bb, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} `, }, } @@ -378,11 +539,13 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t for _, test := range testdata { t.Run(test.tag, func(t *testing.T) { var opdef pb.OpDef + var apidef *pb.ApiDef var buf bytes.Buffer if err := proto.UnmarshalText(test.opdef, &opdef); err != nil { t.Fatal(err) } - if err := generateFunctionForOp(&buf, &opdef); err != nil { + apidef = GetAPIDef(t, &opdef, test.apidef) + if err := generateFunctionForOp(&buf, &opdef, apidef); err != nil { t.Fatal(err) } got, err := format.Source(buf.Bytes()) diff --git a/tensorflow/go/genop/main.go b/tensorflow/go/genop/main.go index 0c7d9be5c13..4a53084ed13 100644 --- a/tensorflow/go/genop/main.go +++ b/tensorflow/go/genop/main.go @@ -27,15 +27,17 @@ import ( "log" "os" "path/filepath" + "strings" "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal" ) func main() { var ( - filename = flag.String("outfile", "", "File to write generated source code to.") - header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. Can be empty") - buf bytes.Buffer + filename = flag.String("outfile", "", "File to write generated source code to.") + header = flag.String("header", "", "Path to a file whose contents will be copied into the generated file. 
Can be empty") + apiDefDirs = flag.String("api_def_dirs", "", "Comma-separated directories containing api_def_*.pbtxt files.") + buf bytes.Buffer ) flag.Parse() if *filename == "" { @@ -51,7 +53,13 @@ func main() { } os.MkdirAll(filepath.Dir(*filename), 0755) - if err := internal.GenerateFunctionsForRegisteredOps(&buf); err != nil { + apiDefDirsList := []string{} + if len(*apiDefDirs) > 0 { + apiDefDirsList = strings.Split(*apiDefDirs, ",") + } + + if err := internal.GenerateFunctionsForRegisteredOps( + &buf, apiDefDirsList); err != nil { log.Fatal(err) } formatted, err := format.Source(buf.Bytes()) diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 789771508e2..5d20701d8fd 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -52,6 +52,12 @@ py_library( ]), ) +py_library( + name = "common", + srcs = ["lib/common.py"], + srcs_version = "PY2AND3", +) + py_library( name = "debug_graphs", srcs = ["lib/debug_graphs.py"], @@ -117,6 +123,7 @@ py_library( srcs = ["lib/source_remote.py"], srcs_version = "PY2AND3", deps = [ + ":common", ":debug_service_pb2_grpc", "//tensorflow/core/debug:debug_service_proto_py", "//tensorflow/python/profiler:tfprof_logger", @@ -193,6 +200,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":command_parser", + ":common", ":debugger_cli_common", ":tensor_format", "//tensorflow/python:framework_for_generated_wrappers", @@ -334,7 +342,11 @@ py_library( name = "grpc_wrapper", srcs = ["wrappers/grpc_wrapper.py"], srcs_version = "PY2AND3", - deps = [":framework"], + deps = [ + ":common", + ":framework", + ":source_remote", + ], ) py_library( @@ -345,6 +357,7 @@ py_library( ":analyzer_cli", ":cli_shared", ":command_parser", + ":common", ":debug_data", ":debugger_cli_common", ":framework", @@ -439,6 +452,20 @@ py_binary( ], ) +py_test( + name = "common_test", + size = "small", + srcs = ["lib/common_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:platform_test", + ], +) + py_test( name = "debug_graphs_test", size = "small", diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index 847f9ec4014..a366f06c4bb 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -55,7 +55,8 @@ def no_rewrite_session_config(): rewriter_config = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, constant_folding=rewriter_config_pb2.RewriterConfig.OFF, - arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) return config_pb2.ConfigProto(graph_options=graph_options) diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py index df972eacf73..24431742eb0 100644 --- a/tensorflow/python/debug/cli/cli_shared.py +++ b/tensorflow/python/debug/cli/cli_shared.py @@ -25,6 +25,7 @@ import six from tensorflow.python.debug.cli import command_parser from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import tensor_format +from tensorflow.python.debug.lib import common from tensorflow.python.framework import ops from tensorflow.python.ops import variables @@ -214,51 +215,6 @@ def 
error(msg): RL("ERROR: " + msg, COLOR_RED)]) -def get_graph_element_name(elem): - """Obtain the name or string representation of a graph element. - - If the graph element has the attribute "name", return name. Otherwise, return - a __str__ representation of the graph element. Certain graph elements, such as - `SparseTensor`s, do not have the attribute "name". - - Args: - elem: The graph element in question. - - Returns: - If the attribute 'name' is available, return the name. Otherwise, return - str(fetch). - """ - - return elem.name if hasattr(elem, "name") else str(elem) - - -def _get_fetch_names(fetches): - """Get a flattened list of the names in run() call fetches. - - Args: - fetches: Fetches of the `Session.run()` call. It maybe a Tensor, an - Operation or a Variable. It may also be nested lists, tuples or - dicts. See doc of `Session.run()` for more details. - - Returns: - (list of str) A flattened list of fetch names from `fetches`. - """ - - lines = [] - if isinstance(fetches, (list, tuple)): - for fetch in fetches: - lines.extend(_get_fetch_names(fetch)) - elif isinstance(fetches, dict): - for key in fetches: - lines.extend(_get_fetch_names(fetches[key])) - else: - # This ought to be a Tensor, an Operation or a Variable, for which the name - # attribute should be available. (Bottom-out condition of the recursion.) - lines.append(get_graph_element_name(fetches)) - - return lines - - def _recommend_command(command, description, indent=2, create_link=False): """Generate a RichTextLines object that describes a recommended command. @@ -327,14 +283,14 @@ def get_run_start_intro(run_call_count, (RichTextLines) Formatted intro message about the `Session.run()` call. """ - fetch_lines = _get_fetch_names(fetches) + fetch_lines = common.get_flattened_names(fetches) if not feed_dict: feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")] else: feed_dict_lines = [] for feed_key in feed_dict: - feed_key_name = get_graph_element_name(feed_key) + feed_key_name = common.get_graph_element_name(feed_key) feed_dict_line = debugger_cli_common.RichLine(" ") feed_dict_line += debugger_cli_common.RichLine( feed_key_name, @@ -446,10 +402,10 @@ def get_run_short_description(run_call_count, description = "run #%d: " % run_call_count if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): - description += "1 fetch (%s); " % get_graph_element_name(fetches) + description += "1 fetch (%s); " % common.get_graph_element_name(fetches) else: # Could be (nested) list, tuple, dict or namedtuple. - num_fetches = len(_get_fetch_names(fetches)) + num_fetches = len(common.get_flattened_names(fetches)) if num_fetches > 1: description += "%d fetches; " % num_fetches else: diff --git a/tensorflow/python/debug/lib/common.py b/tensorflow/python/debug/lib/common.py new file mode 100644 index 00000000000..19a0d8c5010 --- /dev/null +++ b/tensorflow/python/debug/lib/common.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Common values and methods for TensorFlow Debugger.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import json + +GRPC_URL_PREFIX = "grpc://" + +# A key for a Session.run() call. +RunKey = collections.namedtuple("RunKey", ["feed_names", "fetch_names"]) + + +def get_graph_element_name(elem): + """Obtain the name or string representation of a graph element. + + If the graph element has the attribute "name", return name. Otherwise, return + a __str__ representation of the graph element. Certain graph elements, such as + `SparseTensor`s, do not have the attribute "name". + + Args: + elem: The graph element in question. + + Returns: + If the attribute 'name' is available, return the name. Otherwise, return + str(fetch). + """ + + return elem.name if hasattr(elem, "name") else str(elem) + + +def get_flattened_names(feeds_or_fetches): + """Get a flattened list of the names in run() call feeds or fetches. + + Args: + feeds_or_fetches: Feeds or fetches of the `Session.run()` call. It maybe + a Tensor, an Operation or a Variable. It may also be nested lists, tuples + or dicts. See doc of `Session.run()` for more details. + + Returns: + (list of str) A flattened list of fetch names from `feeds_or_fetches`. + """ + + lines = [] + if isinstance(feeds_or_fetches, (list, tuple)): + for item in feeds_or_fetches: + lines.extend(get_flattened_names(item)) + elif isinstance(feeds_or_fetches, dict): + for key in feeds_or_fetches: + lines.extend(get_flattened_names(feeds_or_fetches[key])) + else: + # This ought to be a Tensor, an Operation or a Variable, for which the name + # attribute should be available. (Bottom-out condition of the recursion.) + lines.append(get_graph_element_name(feeds_or_fetches)) + + return lines + + +def get_run_key(feed_dict, fetches): + """Summarize the names of feeds and fetches as a RunKey JSON string. + + Args: + feed_dict: The feed_dict given to the `Session.run()` call. + fetches: The fetches from the `Session.run()` call. + + Returns: + A JSON Array consisting of two items. They first items is a flattened + Array of the names of the feeds. The second item is a flattened Array of + the names of the fetches. + """ + return json.dumps(RunKey(get_flattened_names(feed_dict), + get_flattened_names(fetches))) diff --git a/tensorflow/python/debug/lib/common_test.py b/tensorflow/python/debug/lib/common_test.py new file mode 100644 index 00000000000..5af0dafcf9f --- /dev/null +++ b/tensorflow/python/debug/lib/common_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Unit tests for common values and methods of TensorFlow Debugger.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.python.debug.lib import common +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class CommonTest(test_util.TensorFlowTestCase): + + def testOnFeedOneFetch(self): + a = constant_op.constant(10.0, name="a") + b = constant_op.constant(20.0, name="b") + run_key = common.get_run_key({"a": a}, [b]) + loaded = json.loads(run_key) + self.assertItemsEqual(["a:0"], loaded[0]) + self.assertItemsEqual(["b:0"], loaded[1]) + + def testGetRunKeyFlat(self): + a = constant_op.constant(10.0, name="a") + b = constant_op.constant(20.0, name="b") + run_key = common.get_run_key({"a": a}, [a, b]) + loaded = json.loads(run_key) + self.assertItemsEqual(["a:0"], loaded[0]) + self.assertItemsEqual(["a:0", "b:0"], loaded[1]) + + def testGetRunKeyNestedFetches(self): + a = constant_op.constant(10.0, name="a") + b = constant_op.constant(20.0, name="b") + c = constant_op.constant(30.0, name="c") + d = constant_op.constant(30.0, name="d") + run_key = common.get_run_key( + {}, {"set1": [a, b], "set2": {"c": c, "d": d}}) + loaded = json.loads(run_key) + self.assertItemsEqual([], loaded[0]) + self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], loaded[1]) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py index a637677d7d0..91700469484 100644 --- a/tensorflow/python/debug/lib/grpc_debug_test_server.py +++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py @@ -310,7 +310,7 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): op_log_proto.id_to_string) raise ValueError( "Op '%s' does not exist in the tracebacks received by the debug " - "server.") + "server." % op_name) def query_origin_stack(self): """Query the stack of the origin of the execution call. @@ -348,6 +348,9 @@ class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): Raises: ValueError: If no source file is found at the given file_path. 
""" + if not self._source_files: + raise ValueError( + "This debug server has not received any source file contents yet.") for source_file_proto in self._source_files.source_files: if source_file_proto.file_path == file_path: return source_file_proto.lines[lineno - 1] diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index 99781bd9d90..068e4f81c0d 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -248,7 +248,7 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): self.assertEqual( 14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0])) - def testTensorBoardDebugHooWorks(self): + def testTensorBoardDebugHookWorks(self): u = variables.Variable(2.1, name="u") v = variables.Variable(20.0, name="v") w = math_ops.multiply(u, v, name="w") @@ -261,8 +261,37 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): ["localhost:%d" % self._server_port]) sess = monitored_session._HookedSession(sess, [grpc_debug_hook]) + # Activate watch point on some a tensor before calling sess.run(). + self._server.request_watch("u/read", 0, "DebugIdentity") self.assertAllClose(42.0, sess.run(w)) + # self.assertAllClose(42.0, sess.run(w)) + dump = debug_data.DebugDumpDir(self._dump_root) + self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity")) + + # Check that the server has received the stack trace. + self.assertTrue(self._server.query_op_traceback("u")) + self.assertTrue(self._server.query_op_traceback("u/read")) + self.assertTrue(self._server.query_op_traceback("v")) + self.assertTrue(self._server.query_op_traceback("v/read")) + self.assertTrue(self._server.query_op_traceback("w")) + + # Check that the server has received the python file content. + # Query an arbitrary line to make sure that is the case. + with open(__file__, "rt") as this_source_file: + first_line = this_source_file.readline().strip() + self.assertEqual( + first_line, self._server.query_source_file_line(__file__, 1)) + + self._server.clear_data() + # Call sess.run() again, and verify that this time the traceback and source + # code is not sent, because the graph version is not newer. + self.assertAllClose(42.0, sess.run(w)) + with self.assertRaises(ValueError): + self._server.query_op_traceback("delta_1") + with self.assertRaises(ValueError): + self._server.query_source_file_line(__file__, 1) + def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self): hooks.GrpcDebugHook(["grpc://foo:42424"]) hooks.GrpcDebugHook(["foo:42424"]) @@ -748,6 +777,28 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): # to disable the breakpoint at delta:0:DebugIdentity. self.assertSetEqual(set(), self._server_1.breakpoints) + if i == 0: + # Check that the server has received the stack trace. + self.assertTrue(self._server_1.query_op_traceback("delta_1")) + self.assertTrue(self._server_1.query_op_traceback("delta_2")) + self.assertTrue(self._server_1.query_op_traceback("inc_v_1")) + self.assertTrue(self._server_1.query_op_traceback("inc_v_2")) + # Check that the server has received the python file content. + # Query an arbitrary line to make sure that is the case. 
+ with open(__file__, "rt") as this_source_file: + first_line = this_source_file.readline().strip() + self.assertEqual( + first_line, self._server_1.query_source_file_line(__file__, 1)) + else: + # In later Session.run() calls, the traceback shouldn't have been sent + # because it is already sent in the 1st call. So calling + # query_op_traceback() should lead to an exception, because the test + # debug server clears the data at the beginning of every iteration. + with self.assertRaises(ValueError): + self._server_1.query_op_traceback("delta_1") + with self.assertRaises(ValueError): + self._server_1.query_source_file_line(__file__, 1) + def testGetGrpcDebugWatchesReturnsCorrectAnswer(self): with session.Session() as sess: v = variables.Variable(50.0, name="v") diff --git a/tensorflow/python/debug/lib/source_remote.py b/tensorflow/python/debug/lib/source_remote.py index 9d10d5a8d11..7fd8ceca1dd 100644 --- a/tensorflow/python/debug/lib/source_remote.py +++ b/tensorflow/python/debug/lib/source_remote.py @@ -24,6 +24,7 @@ import grpc from tensorflow.core.debug import debug_service_pb2 from tensorflow.core.protobuf import debug_pb2 +from tensorflow.python.debug.lib import common from tensorflow.python.debug.lib import debug_service_pb2_grpc from tensorflow.python.debug.lib import source_utils from tensorflow.python.platform import gfile @@ -130,6 +131,11 @@ def _send_call_tracebacks(destinations, """ if not isinstance(destinations, list): destinations = [destinations] + # Strip grpc:// prefix, if any is present. + destinations = [ + dest[len(common.GRPC_URL_PREFIX):] + if dest.startswith(common.GRPC_URL_PREFIX) else dest + for dest in destinations] call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION if is_eager_execution diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 16b2018b413..cb9bf95782c 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -17,15 +17,55 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys +import traceback + # Google-internal import(s). +from tensorflow.python.debug.lib import common from tensorflow.python.debug.wrappers import framework +def publish_traceback(debug_server_urls, + graph, + feed_dict, + fetches, + old_graph_version): + """Publish traceback and source code if graph version is new. + + `graph.version` is compared with `old_graph_version`. If the former is higher + (i.e., newer), the graph traceback and the associated source code is sent to + the debug server at the specified gRPC URLs. + + Args: + debug_server_urls: A single gRPC debug server URL as a `str` or a `list` of + debug server URLs. + graph: A Python `tf.Graph` object. + feed_dict: Feed dictionary given to the `Session.run()` call. + fetches: Fetches from the `Session.run()` call. + old_graph_version: Old graph version to compare to. + + Returns: + If `graph.version > old_graph_version`, the new graph version as an `int`. + Else, the `old_graph_version` is returned. + """ + # TODO(cais): Consider moving this back to the top, after grpc becomes a + # pip dependency of tensorflow or tf_debug. 
+ # pylint:disable=g-import-not-at-top + from tensorflow.python.debug.lib import source_remote + # pylint:enable=g-import-not-at-top + if graph.version > old_graph_version: + run_key = common.get_run_key(feed_dict, fetches) + source_remote.send_graph_tracebacks( + debug_server_urls, run_key, traceback.extract_stack(), graph, + send_source=True) + return graph.version + else: + return old_graph_version + + class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): """Debug Session wrapper that send debug data to gRPC stream(s).""" - _GRPC_URL_PREFIX = "grpc://" - def __init__(self, sess, grpc_debug_server_addresses, @@ -94,8 +134,8 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): return self._grpc_debug_server_urls def _normalize_grpc_url(self, address): - return (self._GRPC_URL_PREFIX + address - if not address.startswith(self._GRPC_URL_PREFIX) else address) + return (common.GRPC_URL_PREFIX + address + if not address.startswith(common.GRPC_URL_PREFIX) else address) class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): @@ -126,3 +166,25 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): watch_fn=_gated_grpc_watch_fn, thread_name_filter=thread_name_filter, log_usage=log_usage) + + # Keeps track of the latest version of Python graph object that has been + # sent to the debug servers. + self._sent_graph_version = -sys.maxint + + def run(self, + fetches, + feed_dict=None, + options=None, + run_metadata=None, + callable_runner=None, + callable_runner_args=None): + self._sent_graph_version = publish_traceback( + self._grpc_debug_server_urls, self.graph, feed_dict, fetches, + self._sent_graph_version) + return super(TensorBoardDebugWrapperSession, self).run( + fetches, + feed_dict=feed_dict, + options=options, + run_metadata=run_metadata, + callable_runner=callable_runner, + callable_runner_args=callable_runner_args) diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 43066996248..aa9f0650406 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + from tensorflow.core.protobuf import config_pb2 from tensorflow.python.debug.lib import debug_utils from tensorflow.python.debug.lib import stepper @@ -331,3 +333,13 @@ class TensorBoardDebugHook(GrpcDebugHook): watch_fn=_gated_grpc_watch_fn, thread_name_filter=thread_name_filter, log_usage=log_usage) + + self._grpc_debug_server_addresses = grpc_debug_server_addresses + self._sent_graph_version = -sys.maxint + + def before_run(self, run_context): + self._sent_graph_version = grpc_wrapper.publish_traceback( + self._grpc_debug_server_addresses, run_context.session.graph, + run_context.original_args.feed_dict, run_context.original_args.fetches, + self._sent_graph_version) + return super(TensorBoardDebugHook, self).before_run(run_context) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 5bf6d9d1f4a..c46a4e7d1aa 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -31,6 +31,7 @@ from tensorflow.python.debug.cli import debugger_cli_common from tensorflow.python.debug.cli import profile_analyzer_cli from tensorflow.python.debug.cli import stepper_cli from tensorflow.python.debug.cli import ui_factory 
+from tensorflow.python.debug.lib import common from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import framework @@ -464,7 +465,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): feed_key = None feed_value = None for key in self._feed_dict: - key_name = cli_shared.get_graph_element_name(key) + key_name = common.get_graph_element_name(key) if key_name == tensor_name: feed_key = key_name feed_value = self._feed_dict[key] @@ -561,7 +562,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): list(self._tensor_filters.keys())) if self._feed_dict: # Register tab completion for feed_dict keys. - feed_keys = [cli_shared.get_graph_element_name(key) + feed_keys = [common.get_graph_element_name(key) for key in self._feed_dict.keys()] curses_cli.register_tab_comp_context(["print_feed", "pf"], feed_keys) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index d94a7acd09c..1a1d6d8d053 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -495,14 +495,13 @@ class GraphModeFunction(object): def _get_defun_inputs(args): """Maps the inputs args to graph inputs.""" ret = [] - for a in args: + flat_args = nest.flatten(args) + for a in flat_args: if isinstance(a, ops.Tensor): ret.append(graph_placeholder(a.dtype, a.shape)) - elif type(a) in (tuple, list): - ret.append(_get_defun_inputs(a)) else: ret.append(a) - return tuple(ret) if type(args) is tuple else ret + return nest.pack_sequence_as(args, ret) def _defun_internal(name, func, args, kwds): @@ -582,8 +581,10 @@ def _cache_key(x): return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access if isinstance(x, np.ndarray): return ("array", x.shape, tuple(x.reshape(-1))) - if type(x) in (list, tuple): + if isinstance(x, (list, tuple)): return tuple([_cache_key(a) for a in x]) + if isinstance(x, dict): + return tuple(tuple([_cache_key(k), _cache_key(v)]) for k, v in x.items()) return x diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index e3ea35a6400..0a23254d40e 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function @@ -57,6 +59,20 @@ class FunctionTest(test.TestCase): out = sq(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testNestedInputsGraphMode(self): + matmul = function.defun(math_ops.matmul) + + pair = collections.namedtuple('pair', ['a', 'b']) + + @function.defun + def a_times_b(inputs): + return matmul(inputs.a['a'], inputs.b['b']) + + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + + out = a_times_b(pair({'a': t}, {'b': t})) + self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testGraphModeWithGradients(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') @@ -83,6 +99,22 @@ class FunctionTest(test.TestCase): out = sq_op(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testNestedInputsDefunOpGraphMode(self): + matmul = function.defun(math_ops.matmul) + + pair = collections.namedtuple('pair', ['a', 'b']) + def a_times_b(inputs): + return matmul(inputs.a['a'], inputs.b['b']) + + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + + inputs = 
pair({'a': t}, {'b': t}) + sq_op = function.make_defun_op(a_times_b, inputs) + + self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) + out = sq_op(inputs) + self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testNestedOutputDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 4902978e2dc..7f07deebef3 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <array> + #include "tensorflow/python/lib/core/bfloat16.h" #include "tensorflow/core/framework/numeric_types.h" @@ -477,8 +479,61 @@ bool RegisterBfloat16Cast(int numpy_type, bool cast_is_safe) { return true; } +template <typename InType, typename OutType, typename Functor> +void BinaryUFunc(char** args, npy_intp* dimensions, npy_intp* steps, + void* data) { + const char* i0 = args[0]; + const char* i1 = args[1]; + char* o = args[2]; + for (npy_intp k = 0; k < *dimensions; k++) { + InType x = *reinterpret_cast<const InType*>(i0); + InType y = *reinterpret_cast<const InType*>(i1); + *reinterpret_cast<OutType*>(o) = Functor()(x, y); + i0 += steps[0]; + i1 += steps[1]; + o += steps[2]; + } +} + +template <typename Functor> +void CompareUFunc(char** args, npy_intp* dimensions, npy_intp* steps, + void* data) { + BinaryUFunc<bfloat16, npy_bool, Functor>(args, dimensions, steps, data); +} + +struct Bfloat16EqFunctor { + npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; } +}; +struct Bfloat16NeFunctor { + npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; } +}; +struct Bfloat16LtFunctor { + npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; } +}; +struct Bfloat16GtFunctor { + npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; } +}; +struct Bfloat16LeFunctor { + npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; } +}; +struct Bfloat16GeFunctor { + npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; } +}; + // Initializes the module. bool Initialize() { + // It's critical to import umath to avoid crash in open source build.
+ import_umath1(false); + + Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy")); + if (!numpy_str) { + return false; + } + Safe_PyObjectPtr numpy = make_safe(PyImport_Import(numpy_str.get())); + if (!numpy) { + return false; + } + // We hit a mysterious crash if we haven't initialized numpy before this: PyBfloat16_Type.tp_base = &PyGenericArrType_Type; @@ -536,6 +591,57 @@ bool Initialize() { /*cast_is_safe=*/true)) { return false; } + + // Register ufuncs + auto register_ufunc = [&](const char* name, PyUFuncGenericFunction fn, + const std::array<int, 3>& types) { + Safe_PyObjectPtr ufunc_obj = + make_safe(PyObject_GetAttrString(numpy.get(), name)); + if (!ufunc_obj) { + return false; + } + PyUFuncObject* ufunc = reinterpret_cast<PyUFuncObject*>(ufunc_obj.get()); + if (types.size() != ufunc->nargs) { + PyErr_Format(PyExc_AssertionError, + "ufunc %s takes %d arguments, loop takes %lu", name, + ufunc->nargs, types.size()); + return false; + } + if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16_, fn, + const_cast<int*>(types.data()), + nullptr) < 0) { + return false; + } + return true; + }; + + // Comparisons + const std::array<int, 3> compare_types = {npy_bfloat16_, npy_bfloat16_, NPY_BOOL}; + + if (!register_ufunc("equal", CompareUFunc<Bfloat16EqFunctor>, + compare_types)) { + return false; + } + if (!register_ufunc("not_equal", CompareUFunc<Bfloat16NeFunctor>, + compare_types)) { + return false; + } + if (!register_ufunc("less", CompareUFunc<Bfloat16LtFunctor>, compare_types)) { + return false; + } + if (!register_ufunc("greater", CompareUFunc<Bfloat16GtFunctor>, + compare_types)) { + return false; + } + if (!register_ufunc("less_equal", CompareUFunc<Bfloat16LeFunctor>, + compare_types)) { + return false; + } + if (!register_ufunc("greater_equal", CompareUFunc<Bfloat16GeFunctor>, + compare_types)) { + return false; + } return true; } diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 0872348c51b..985a11272c8 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -172,6 +172,24 @@ class Bfloat16NumPyTest(test.TestCase): self.assertEqual("[[bfloat16(1) bfloat16(2) bfloat16(3)]]", str(x)) self.assertAllEqual(x, x) self.assertAllClose(x, x) + self.assertTrue((x == x).all()) + + def testComparisons(self): + x = np.array([401408, 7, -32], dtype=np.float32) + bx = x.astype(bfloat16) + y = np.array([82432, 7, 0], dtype=np.float32) + by = y.astype(bfloat16) + self.assertAllEqual(x == y, bx == by) + self.assertAllEqual(x != y, bx != by) + self.assertAllEqual(x < y, bx < by) + self.assertAllEqual(x > y, bx > by) + self.assertAllEqual(x <= y, bx <= by) + self.assertAllEqual(x >= y, bx >= by) + + def testEqual2(self): + a = np.array([401408], bfloat16) + b = np.array([82432], bfloat16) + self.assertFalse(a.__eq__(b)) def testCasts(self): for dtype in [ diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h index 0eafe890dba..25322b458b8 100644 --- a/tensorflow/python/lib/core/numpy.h +++ b/tensorflow/python/lib/core/numpy.h @@ -32,6 +32,7 @@ limitations under the License.
#include #include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" namespace tensorflow { diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 39c64398116..e34aa7cc2ca 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -25,6 +25,7 @@ py_library( ":main_op", ":signature_constants", ":signature_def_utils", + ":simple_save", ":tag_constants", ":utils", "//tensorflow/python:util", @@ -89,6 +90,23 @@ py_library( ], ) +py_library( + name = "simple_save", + srcs = [ + "simple_save.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":builder", + ":signature_constants", + ":signature_def_utils", + ":tag_constants", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:lib", + "//tensorflow/python:util", + ], +) + py_library( name = "main_op", srcs = [ @@ -198,6 +216,22 @@ py_test( ], ) +py_test( + name = "simple_save_test", + size = "small", + srcs = ["simple_save_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":loader", + ":signature_constants", + ":simple_save", + ":tag_constants", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:variables", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py index 8c59f7afe77..caabd7bc304 100644 --- a/tensorflow/python/saved_model/saved_model.py +++ b/tensorflow/python/saved_model/saved_model.py @@ -30,6 +30,9 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import utils # pylint: enable=unused-import +# pylint: disable=wildcard-import +from tensorflow.python.saved_model.simple_save import * +# pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -41,6 +44,7 @@ _allowed_symbols = [ "main_op", "signature_constants", "signature_def_utils", + "simple_save", "tag_constants", "utils", ] diff --git a/tensorflow/contrib/saved_model/python/saved_model/utils.py b/tensorflow/python/saved_model/simple_save.py similarity index 97% rename from tensorflow/contrib/saved_model/python/saved_model/utils.py rename to tensorflow/python/saved_model/simple_save.py index 9f34af64a62..9a81e5cd807 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/utils.py +++ b/tensorflow/python/saved_model/simple_save.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""SavedModel utility functions.""" +"""SavedModel simple save functionality.""" from __future__ import absolute_import from __future__ import division @@ -39,7 +39,7 @@ def simple_save(session, export_dir, inputs, outputs, legacy_init_op=None): to configure a SavedModel, this method has a few practical implications: - It will be treated as a graph for inference / serving (i.e. uses the tag `tag_constants.SERVING`) - - The saved model will load in TensorFlow Serving and supports the + - The SavedModel will load in TensorFlow Serving and supports the [Predict API](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/predict.proto). 
To use the Classify, Regress, or MultiInference APIs, please use either diff --git a/tensorflow/contrib/saved_model/python/saved_model/utils_test.py b/tensorflow/python/saved_model/simple_save_test.py similarity index 95% rename from tensorflow/contrib/saved_model/python/saved_model/utils_test.py rename to tensorflow/python/saved_model/simple_save_test.py index 36dfb88871f..b2fa40d4f13 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/utils_test.py +++ b/tensorflow/python/saved_model/simple_save_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for saved_model utils.""" +"""Tests for SavedModel simple save functionality.""" from __future__ import absolute_import from __future__ import division @@ -20,16 +20,16 @@ from __future__ import print_function import os -from tensorflow.contrib.saved_model.python.saved_model import utils from tensorflow.python.framework import ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import simple_save from tensorflow.python.saved_model import tag_constants -class UtilsTest(test.TestCase): +class SimpleSaveTest(test.TestCase): def _init_and_validate_variable(self, sess, variable_name, variable_value): v = variables.Variable(variable_value, name=variable_name) @@ -65,7 +65,7 @@ class UtilsTest(test.TestCase): var_y = self._init_and_validate_variable(sess, "var_y", 2) inputs = {"x": var_x} outputs = {"y": var_y} - utils.simple_save(sess, export_dir, inputs, outputs) + simple_save.simple_save(sess, export_dir, inputs, outputs) # Restore the graph with a valid tag and check the global variables and # signature def map. diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt index 5683766b289..e1a0385092c 100644 --- a/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.saved_model.pbtxt @@ -32,4 +32,8 @@ tf_module { name: "utils" mtype: "" } + member_method { + name: "simple_save" + argspec: "args=[\'session\', \'export_dir\', \'inputs\', \'outputs\', \'legacy_init_op\'], varargs=None, keywords=None, defaults=[\'None\'], " + } }
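
The change above exposes simple_save under tf.saved_model with the signature recorded in the golden file: simple_save(session, export_dir, inputs, outputs, legacy_init_op=None). A minimal usage sketch follows, assuming a TensorFlow 1.x graph-mode session; the placeholder, the dense layer, and the export path are hypothetical examples for illustration, not part of this patch.

import tensorflow as tf

# Hypothetical toy graph: one feedable input and one fetchable output.
x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
y = tf.layers.dense(x, 1, name="y")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Writes a SavedModel tagged with tag_constants.SERVING and a Predict
  # signature, so it can be loaded by TensorFlow Serving or by
  # tf.saved_model.loader.load() as exercised in simple_save_test.py.
  tf.saved_model.simple_save(
      sess,
      export_dir="/tmp/simple_save_example",
      inputs={"x": x},
      outputs={"y": y})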