From 8d9eda26be345ace2e110feb0cf9a2500990eb82 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Wed, 22 Nov 2017 15:51:33 -0800 Subject: [PATCH 01/15] Remove hardcoded discovery document now that the TPU alpha API definitions are public (https://www.googleapis.com/discovery/v1/apis/tpu/v1alpha1/rest). PiperOrigin-RevId: 176710985 --- .../python/training/tpu_cluster_resolver.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index f0144e9faa2..c74da9cabd6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -80,13 +80,9 @@ class TPUClusterResolver(ClusterResolver): raise ImportError('googleapiclient must be installed before using the ' 'TPU cluster resolver') - # TODO(b/67375680): Remove custom URL once TPU APIs are finalized self._service = discovery.build( - 'tpu', - 'v1', - credentials=self._credentials, - discoveryServiceUrl='https://storage.googleapis.com' - '/tpu-api-definition/v1alpha1.json') + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service From 0aa09c2de25ddb321405656ae33031773690bd5e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Nov 2017 15:54:27 -0800 Subject: [PATCH 02/15] dynamic_rnn: put all ops in the same scope This clarifies the graph visualization a bit. PiperOrigin-RevId: 176711260 --- tensorflow/python/ops/rnn.py | 47 ++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 436872f0447..e30b19842f0 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -565,33 +565,34 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, if not _like_rnncell(cell): raise TypeError("cell must be an instance of RNNCell") - # By default, time_major==False and inputs are batch-major: shaped - # [batch, time, depth] - # For internal calculations, we transpose to [time, batch, depth] - flat_input = nest.flatten(inputs) - - if not time_major: - # (B,T,D) => (T,B,D) - flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] - flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) - - parallel_iterations = parallel_iterations or 32 - if sequence_length is not None: - sequence_length = math_ops.to_int32(sequence_length) - if sequence_length.get_shape().ndims not in (None, 1): - raise ValueError( - "sequence_length must be a vector of length batch_size, " - "but saw shape: %s" % sequence_length.get_shape()) - sequence_length = array_ops.identity( # Just to find it in the graph. - sequence_length, name="sequence_length") - - # Create a new scope in which the caching device is either - # determined by the parent scope, or is set to place the cached - # Variable using the same placement as for the rest of the RNN. with vs.variable_scope(scope or "rnn") as varscope: + # Create a new scope in which the caching device is either + # determined by the parent scope, or is set to place the cached + # Variable using the same placement as for the rest of the RNN. 
if context.in_graph_mode(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) + + # By default, time_major==False and inputs are batch-major: shaped + # [batch, time, depth] + # For internal calculations, we transpose to [time, batch, depth] + flat_input = nest.flatten(inputs) + + if not time_major: + # (B,T,D) => (T,B,D) + flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] + flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) + + parallel_iterations = parallel_iterations or 32 + if sequence_length is not None: + sequence_length = math_ops.to_int32(sequence_length) + if sequence_length.get_shape().ndims not in (None, 1): + raise ValueError( + "sequence_length must be a vector of length batch_size, " + "but saw shape: %s" % sequence_length.get_shape()) + sequence_length = array_ops.identity( # Just to find it in the graph. + sequence_length, name="sequence_length") + batch_size = _best_effort_input_batch_size(flat_input) if initial_state is not None: From 806754888188e40430bc96ad33c5f51282c2d338 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 22 Nov 2017 16:30:44 -0800 Subject: [PATCH 03/15] Acquire the GIL before working with PyLists and PyDict. PiperOrigin-RevId: 176714705 --- tensorflow/python/grappler/cluster.i | 9 ++++++++- tensorflow/python/grappler/item.i | 5 +++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 1838c40e463..5a7cdf26f8e 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -138,6 +138,7 @@ tensorflow::Status _GetOpPerformanceDataAndRunTime( static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) { const std::unordered_map& devices = cluster->GetDevices(); + PyGILState_STATE gstate = PyGILState_Ensure(); PyObject* result = PyList_New(devices.size()); int i = 0; for (auto& dev : devices) { @@ -150,6 +151,7 @@ static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) { PyList_SetItem(result, i, dev_obj); ++i; } + PyGILState_Release(gstate); return result; } @@ -184,6 +186,7 @@ static PyObject* TF_MeasureCosts( if (!status.ok()) { Py_RETURN_NONE; } + PyGILState_STATE gstate = PyGILState_Ensure(); PyObject* op_perf_objs = PyList_New( op_performance_data.op_performance_size()); for (int i = 0; i < op_performance_data.op_performance_size(); i++) { @@ -211,8 +214,10 @@ static PyObject* TF_MeasureCosts( status = tensorflow::Status(tensorflow::error::Code::INTERNAL, "Error setting return tuples."); tensorflow::Set_TF_Status_from_Status(out_status, status); - Py_RETURN_NONE; + Py_INCREF(Py_None); + ret = Py_None; } + PyGILState_Release(gstate); return ret; } @@ -240,6 +245,7 @@ static PyObject* TF_DeterminePeakMemoryUsage( Py_RETURN_NONE; } + PyGILState_STATE gstate = PyGILState_Ensure(); PyObject* result = PyDict_New(); for (const auto& device : cluster->GetDevices()) { const tensorflow::grappler::GraphMemory::MemoryUsage& usage = @@ -261,6 +267,7 @@ static PyObject* TF_DeterminePeakMemoryUsage( PyTuple_SetItem(ret, 1, per_device); PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret); } + PyGILState_Release(gstate); return result; } diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index 8c346b44387..2fa502b81da 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -101,6 +101,7 @@ static PyObject* TF_GetOpProperties(const 
tensorflow::grappler::GrapplerItem* it Py_RETURN_NONE; } + PyGILState_STATE gstate = PyGILState_Ensure(); PyObject* props = PyDict_New(); for (const auto& node : item->graph.node()) { const string& node_name = node.name(); @@ -115,8 +116,8 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it PyList_SetItem(prop, i, output_prop); } CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop)); - } - + } + PyGILState_Release(gstate); return props; } From 85fa6bdfe40f24259b3cec19637567ed3cff7370 Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Wed, 22 Nov 2017 16:35:13 -0800 Subject: [PATCH 04/15] [tf.data] Patch for thread safe IgnoreErrorDataset. PiperOrigin-RevId: 176715082 --- .../core/kernels/ignore_errors_dataset_op.cc | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/ignore_errors_dataset_op.cc index 43ba5ab7dd0..8cf263d87fe 100644 --- a/tensorflow/core/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/core/kernels/ignore_errors_dataset_op.cc @@ -79,16 +79,20 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { - if (!input_impl_) { - *end_of_sequence = true; - return Status::OK(); - } - Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); - while (!s.ok()) { - out_tensors->clear(); - s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + { + tf_shared_lock l(mu_); + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + while (!s.ok()) { + out_tensors->clear(); + s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } } if (*end_of_sequence) { + mutex_lock l(mu_); input_impl_.reset(); } return Status::OK(); @@ -96,6 +100,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); if (input_impl_) TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); else @@ -106,6 +111,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { Status RestoreInternal(OpKernelContext* ctx, IteratorStateReader* reader) override { + mutex_lock l(mu_); if (reader->Contains(full_name("input_impls_empty"))) input_impl_.reset(); else @@ -114,7 +120,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { } private: - std::unique_ptr input_impl_; + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); }; const DatasetBase* const input_; From 1885db7ffa6cea7bacfb7ef1507f3103cd1829f0 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Wed, 22 Nov 2017 16:36:48 -0800 Subject: [PATCH 05/15] Only convert the layout if the node is placed on GPU. 
PiperOrigin-RevId: 176715219 --- tensorflow/core/grappler/optimizers/BUILD | 2 + .../grappler/optimizers/layout_optimizer.cc | 146 +++++++++++------- .../grappler/optimizers/layout_optimizer.h | 5 +- .../optimizers/layout_optimizer_test.cc | 67 +++++++- 4 files changed, 163 insertions(+), 57 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index e1275560545..5d9eb8e0b12 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -312,6 +312,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_optimizer", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", @@ -320,6 +321,7 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/costs:virtual_placer", "//tensorflow/core/grappler/utils:frame", ], ) diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 89ebd8e98fc..31c3ba68637 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -27,7 +27,9 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/frame.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace grappler { @@ -109,11 +111,13 @@ bool IsMaxPoolGradV1(const NodeDef& node) { class GraphProcessor { public: - GraphProcessor(GraphDef* graph, NodeMap* node_map, - const std::unordered_set& nodes_to_preserve) - : graph_(graph), - node_map_(node_map), - nodes_to_preserve_(nodes_to_preserve) {} + GraphProcessor(const VirtualPlacer& virtual_placer, + const std::unordered_set& nodes_to_preserve, + GraphDef* graph, NodeMap* node_map) + : virtual_placer_(virtual_placer), + nodes_to_preserve_(nodes_to_preserve), + graph_(graph), + node_map_(node_map) {} protected: NodeDef* AddNodePermConst(const string& name, const string& device, @@ -122,7 +126,6 @@ class GraphProcessor { node_map_->AddNode(name, node); node->set_name(name); node->set_op("Const"); - node->set_device(device); AttrValue attr_data_type; attr_data_type.set_type(DT_INT32); node->mutable_attr()->insert({"dtype", attr_data_type}); @@ -133,6 +136,13 @@ class GraphProcessor { } tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); node->mutable_attr()->insert({"value", attr_tensor}); + string device_name; + if (device.empty()) { + device_name = virtual_placer_.get_canonical_device_name(*node); + } else { + device_name = device; + } + node->set_device(device_name); return node; } @@ -142,7 +152,6 @@ class GraphProcessor { node_map_->AddNode(name, node); node->set_name(name); node->set_op("Const"); - node->set_device(device); AttrValue attr_data_type; attr_data_type.set_type(dtype); node->mutable_attr()->insert({"dtype", attr_data_type}); @@ -151,6 +160,13 @@ class GraphProcessor { tensor.scalar()() = value; tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); node->mutable_attr()->insert({"value", attr_tensor}); + string device_name; + if (device.empty()) { + device_name = virtual_placer_.get_canonical_device_name(*node); + } else { + device_name = device; + 
} + node->set_device(device_name); return node; } @@ -159,7 +175,6 @@ class GraphProcessor { node_map_->AddNode(name, node); node->set_name(name); node->set_op("Const"); - node->set_device(device); AttrValue attr_data_type; attr_data_type.set_type(DT_INT32); node->mutable_attr()->insert({"dtype", attr_data_type}); @@ -172,26 +187,37 @@ class GraphProcessor { } tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); node->mutable_attr()->insert({"value", attr_tensor}); + string device_name; + if (device.empty()) { + device_name = virtual_placer_.get_canonical_device_name(*node); + } else { + device_name = device; + } + node->set_device(device_name); return node; } + const VirtualPlacer& virtual_placer_; + const std::unordered_set& nodes_to_preserve_; GraphDef* graph_; NodeMap* node_map_; - const std::unordered_set& nodes_to_preserve_; }; struct OptimizeContext { OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map, + const VirtualPlacer& virtual_placer, const std::unordered_set& nodes_to_preserve, bool is_in_frame) : graph(graph), node(node), node_map(node_map), + virtual_placer(virtual_placer), nodes_to_preserve(nodes_to_preserve), is_in_frame(is_in_frame) {} GraphDef* graph; NodeDef* node; NodeMap* node_map; + const VirtualPlacer& virtual_placer; const std::unordered_set& nodes_to_preserve; bool is_in_frame; }; @@ -199,8 +225,8 @@ struct OptimizeContext { class NodeProcessor : public GraphProcessor { public: explicit NodeProcessor(const OptimizeContext& opt_cxt) - : GraphProcessor(opt_cxt.graph, opt_cxt.node_map, - opt_cxt.nodes_to_preserve), + : GraphProcessor(opt_cxt.virtual_placer, opt_cxt.nodes_to_preserve, + opt_cxt.graph, opt_cxt.node_map), node_(opt_cxt.node), is_in_frame_(opt_cxt.is_in_frame) {} virtual ~NodeProcessor() {} @@ -257,7 +283,25 @@ class NodeProcessor : public GraphProcessor { } virtual bool ShouldProcess() const { - return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs(); + return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs() && + IsOnGPU(); + } + + virtual bool IsOnGPU() const { + string device_name; + if (node_->device().empty()) { + device_name = virtual_placer_.get_canonical_device_name(*node_); + } else { + device_name = node_->device(); + } + string device; + string not_used; + if (DeviceNameUtils::SplitDeviceName(device_name, ¬_used, &device) && + (StringPiece(str_util::Lowercase(device))) + .contains(str_util::Lowercase(DEVICE_GPU))) { + return true; + } + return false; } void UpdateAttrDataFormat() { @@ -536,6 +580,9 @@ class BiasAddGradProcessor : public NodeProcessor { if (MustPreserve()) { return false; } + if (!IsOnGPU()) { + return false; + } auto input = node_map_->GetNode(node_->input(0)); if (input) { if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) { @@ -556,7 +603,7 @@ class Conv2DProcessor : public NodeProcessor { protected: bool ShouldProcess() const override { return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs() && - (!IsGemmUsed() || no_gemm_); + (!IsGemmUsed() || no_gemm_) && IsOnGPU(); } TensorShapeProto GetShape(const string& input_name) const { @@ -693,7 +740,7 @@ class AgnosticNodeProcessor : public NodeProcessor { protected: bool ShouldProcess() const override { return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && - IsNodeAfterNCHWToNHWC(); + IsNodeAfterNCHWToNHWC() && IsOnGPU(); } bool IsNodeAfterNCHWToNHWC() const { @@ -746,7 +793,8 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { return !MustPreserve() && 
IsDimsFour(*node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() && (Is4DOperateWithND(4) || Is4DOperateWithScalar() || - Is4DOperateWithVector()); + Is4DOperateWithVector()) && + IsOnGPU(); } std::vector GetInputPos() const override { @@ -855,7 +903,7 @@ class ConcatProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && - IsNodeAfterNCHWToNHWC() && IsAlongDimC(); + IsNodeAfterNCHWToNHWC() && IsAlongDimC() && IsOnGPU(); } std::vector GetInputPos() const override { @@ -920,7 +968,7 @@ class PadProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && - IsNodeAfterNCHWToNHWC() && PaddingSupported(); + IsNodeAfterNCHWToNHWC() && PaddingSupported() && IsOnGPU(); } Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); } @@ -1132,7 +1180,8 @@ class SqueezeProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { return !MustPreserve() && IsDimsN(*node_, 2) && HasOutputs() && - IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW(); + IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() && + IsOnGPU(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); } @@ -1183,7 +1232,7 @@ class SumProcessor : public AgnosticNodeProcessor { auto input0 = node_map_->GetNode(node_->input(0)); return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) && - IsAlongDimNHW(); + IsAlongDimNHW() && IsOnGPU(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); } @@ -1243,42 +1292,41 @@ class SumProcessor : public AgnosticNodeProcessor { class DataLayoutOptimizer : GraphProcessor { public: explicit DataLayoutOptimizer( - LayoutOptimizer::TuningConfig config, - const std::unordered_set& nodes_to_preserve, - const string& default_device, GraphDef* graph, NodeMap* node_map) - : GraphProcessor(graph, node_map, nodes_to_preserve), - config_(config), - default_device_(default_device) {} + const VirtualPlacer& virtual_placer, + const LayoutOptimizer::TuningConfig& config, + const std::unordered_set& nodes_to_preserve, GraphDef* graph, + NodeMap* node_map) + : GraphProcessor(virtual_placer, nodes_to_preserve, graph, node_map), + config_(config) {} Status Optimize() { - LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size(); + VLOG(1) << "Number of nodes for original graph: " << graph_->node_size(); TF_RETURN_IF_ERROR(Expand()); - LOG(INFO) << "Number of nodes after Expand: " << graph_->node_size(); + VLOG(1) << "Number of nodes after Expand: " << graph_->node_size(); TF_RETURN_IF_ERROR(Collapse()); - LOG(INFO) << "Number of nodes after Collapse: " << graph_->node_size(); + VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size(); return Status::OK(); } private: NodeDef* AddNodePermNHWCToNCHW() { - return AddNodePermConst(kPermNHWCToNCHW, default_device_, {0, 3, 1, 2}); + return AddNodePermConst(kPermNHWCToNCHW, "", {0, 3, 1, 2}); } NodeDef* AddNodePermNCHWToNHWC() { - return AddNodePermConst(kPermNCHWToNHWC, default_device_, {0, 2, 3, 1}); + return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1}); } NodeDef* AddNodeConcatConst() { - return AddNodeConstScalar(kConcatConst, default_device_, DT_INT32, 1); + return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1); } NodeDef* AddNodeGatherAxisConst() { - return 
AddNodeConstScalar(kGatherAxisConst, default_device_, DT_INT32, 0); + return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0); } NodeDef* AddNodeReductionConst() { - return GraphProcessor::AddNodeReductionConst(kReductionConst, - default_device_); + return GraphProcessor::AddNodeReductionConst(kReductionConst, ""); } // Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic. @@ -1295,8 +1343,8 @@ class DataLayoutOptimizer : GraphProcessor { ops_format_supported.end()) { auto node = graph_->mutable_node(i); bool is_in_frame = !frames[node].empty(); - OptimizeContext opt_cxt(graph_, node, node_map_, nodes_to_preserve_, - is_in_frame); + OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_, + nodes_to_preserve_, is_in_frame); std::unique_ptr node_processor; if (IsAvgPoolGrad(*node)) { node_processor.reset(new AvgPoolGradProcessor(opt_cxt)); @@ -1343,8 +1391,8 @@ class DataLayoutOptimizer : GraphProcessor { ops_format_agnostic.end()) { auto node = graph_->mutable_node(i); bool is_in_frame = !frames[node].empty(); - OptimizeContext opt_cxt(graph_, node, node_map_, nodes_to_preserve_, - is_in_frame); + OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_, + nodes_to_preserve_, is_in_frame); std::unique_ptr node_processor; if (IsAddN(*node)) { node_processor.reset(new AddNProcessor(opt_cxt)); @@ -1419,8 +1467,7 @@ class DataLayoutOptimizer : GraphProcessor { return Status::OK(); } - LayoutOptimizer::TuningConfig config_; - string default_device_; + const LayoutOptimizer::TuningConfig& config_; }; int GetNumTranspose(const GraphDef& graph) { @@ -1430,7 +1477,7 @@ int GetNumTranspose(const GraphDef& graph) { number++; } } - LOG(INFO) << "Number of Transpose nodes: " << number; + VLOG(1) << "Number of Transpose nodes: " << number; return number; } @@ -1455,7 +1502,6 @@ int GetNumGPUs(const Cluster& cluster) { Status LayoutOptimizer::Tune(const GrapplerItem& item, const GraphProperties& graph_properties, - const string& default_device, const TuningConfig& config, GraphDef* output) { auto status = graph_properties.AnnotateOutputShapes(output); if (!status.ok()) { @@ -1463,8 +1509,8 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item, return status; } NodeMap node_map(output); - DataLayoutOptimizer layout_optimizer(config, nodes_to_preserve_, - default_device, output, &node_map); + DataLayoutOptimizer layout_optimizer(*virtual_placer_, config, + nodes_to_preserve_, output, &node_map); status = layout_optimizer.Optimize(); return status; } @@ -1477,6 +1523,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } + virtual_placer_.reset(new VirtualPlacer(cluster)); nodes_to_preserve_ = item.NodesToPreserve(); GraphProperties graph_properties(item); auto status = graph_properties.InferStatically(); @@ -1487,20 +1534,13 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TuningConfig config; config.no_gemm = false; - string default_device = "/job:localhost/replica:0/task:0/cpu:0"; - if (cluster) { - if (!cluster->GetDevices().empty()) { - default_device = cluster->GetDevices().begin()->first; - } - } - - status = Tune(item, graph_properties, default_device, config, output); + status = Tune(item, graph_properties, config, output); // This is based on an empirical observation that if the introduced Transpose // nodes is more than 30, not using GEMM implementation would result in better // performance. 
if (status.ok() && GetNumTranspose(*output) > 30) { config.no_gemm = true; - status = Tune(item, graph_properties, default_device, config, output); + status = Tune(item, graph_properties, config, output); } if (!status.ok()) { diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h index f5dd70356a3..357205828dd 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.h +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_ #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/costs/virtual_placer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" namespace tensorflow { @@ -47,10 +48,10 @@ class LayoutOptimizer : public GraphOptimizer { const GraphDef& optimize_output, double result) override; private: + std::unique_ptr virtual_placer_; std::unordered_set nodes_to_preserve_; Status Tune(const GrapplerItem& item, const GraphProperties& graph_properties, - const string& default_device, const TuningConfig& config, - GraphDef* output); + const TuningConfig& config, GraphDef* output); }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index 5d2d90b193f..d4ab42ad60f 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -39,6 +39,11 @@ class LayoutOptimizerTest : public ::testing::Test { Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, const string& padding) { + return SimpleConv2D(s, input_size, filter_size, padding, ""); + } + + Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, + const string& padding, const string& device) { int batch_size = 128; int input_height = input_size; int input_width = input_size; @@ -59,8 +64,8 @@ class LayoutOptimizerTest : public ::testing::Test { Output filter = ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); - Output conv = ops::Conv2D(s->WithOpName("Conv2D"), input, filter, - {1, stride, stride, 1}, padding); + Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input, + filter, {1, stride, stride, 1}, padding); return conv; } @@ -278,6 +283,64 @@ TEST_F(LayoutOptimizerTest, PreserveFetch) { EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC"); } +TEST_F(LayoutOptimizerTest, EmptyDevice) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto conv_node = node_map.GetNode("Conv2D"); + EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW"); +} + +TEST_F(LayoutOptimizerTest, GPUDevice) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = + SimpleConv2D(&s, 3, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); 
+ NodeMap node_map(&output); + auto conv_node = node_map.GetNode("Conv2D"); + EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW"); +} + +TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = + SimpleConv2D(&s, 3, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto conv_node = node_map.GetNode("Conv2D"); + EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC"); +} + +TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID", "/CPU:0"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto conv_node = node_map.GetNode("Conv2D"); + EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC"); +} + } // namespace } // namespace grappler } // namespace tensorflow From 5548cfc2eda3614a318e04cd922512be99aefefe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Nov 2017 16:45:54 -0800 Subject: [PATCH 06/15] [XLA] Add a convenience function that returns a platform with the given name. PiperOrigin-RevId: 176715886 --- .../compiler/xla/service/platform_util.cc | 22 +++++++++++++++++++ .../compiler/xla/service/platform_util.h | 8 +++++++ 2 files changed, 30 insertions(+) diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 3a1818de82d..63f3bfb36ce 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -94,6 +94,28 @@ PlatformUtil::GetSupportedPlatforms() { platforms_string.c_str()); } +/*static*/ StatusOr PlatformUtil::GetPlatform( + const string& platform_name) { + using tensorflow::str_util::Lowercase; + string platform_str = Lowercase(platform_name); + // "cpu" and "host" mean the same thing. + if (platform_str == "cpu") { + platform_str = "host"; + } + // "gpu" and "cuda" mean the same thing. + if (platform_str == "gpu") { + platform_str = "cuda"; + } + + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) == platform_str) { + return platform; + } + } + return InvalidArgument("platform %s not found", platform_name.c_str()); +} + // Returns whether the device underlying the given StreamExecutor is supported // by XLA. static bool IsDeviceSupported(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index eac57370308..a59d4ffe87f 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -16,11 +16,14 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#include #include #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -39,6 +42,11 @@ class PlatformUtil { // default platform. Otherwise returns an error. static StatusOr GetDefaultPlatform(); + // Returns the platform according to the given name. Returns error if there is + // no such platform. + static StatusOr GetPlatform( + const string& platform_name); + // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an // element is nullptr, then the device is present by not supported by XLA. From 58d0e5b6c3cf48acb68c87d2deca0e304b075b1a Mon Sep 17 00:00:00 2001 From: Igor Ganichev Date: Wed, 22 Nov 2017 17:18:52 -0800 Subject: [PATCH 07/15] Add persistent GradientTape support Added two simple tests for persistent tapes and did a manual test that calling "del" on gradient tape releases all tensors. Also: - Add missing Py_DECREF to error case in MakeTensorIDList - Make a couple error messages more descriptive PiperOrigin-RevId: 176718477 --- tensorflow/c/eager/tape.h | 66 +++++++++++++++++------ tensorflow/python/eager/backprop.py | 38 +++++++++++-- tensorflow/python/eager/backprop_test.py | 31 +++++++++++ tensorflow/python/eager/pywrap_tfe.h | 3 +- tensorflow/python/eager/pywrap_tfe_src.cc | 13 +++-- tensorflow/python/eager/tape.py | 4 +- 6 files changed, 128 insertions(+), 27 deletions(-) diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 84b40a18198..f52248e7d56 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -106,6 +106,12 @@ class VSpace { // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; + + // Lets this VSpace know that it can release resources held by the + // `backward_function`, It will not be called again. + // `backward_function` must not be null. + virtual void ReleaseBackwardFunction( + BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -113,7 +119,11 @@ class VSpace { template class GradientTape { public: - GradientTape() {} + // If `persistent` is true, GradientTape will not eagerly delete backward + // functions (and hence the tensors they keep alive). Instead, everything + // is deleted in ~GradientTape. Persistent GradientTapes are useful when + // users want to compute multiple gradients over the same tape. + GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { pair.second.backward_function_deleter(); @@ -150,6 +160,10 @@ class GradientTape { // Map from tensor id to number of remaining usages (i.e. how many entries in // the tape refer to it); to aid in tape garbage collection. std::unordered_map tensor_usage_; + + // If true, all activations are deleted in the first call to ComputeGradient. + // Else, only when this is destructed. + bool persistent_; }; // Template instantiations here @@ -279,11 +293,16 @@ struct BackpropInitialState { std::unordered_map op_missing_tensor; }; +// If `persistent_tape` is true, op_tape is not changed and none of the +// backwards functions are deleted. 
+// If `persistent_tape` is false, op_tape is cleared and backwards functions +// not needed for gradient computation are deleted. Backwards functions that +// are needed, are copied and returned in BackpropInitialState. template BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape op_tape, - const std::unordered_set& sources_set) { + OpTape* op_tape, + const std::unordered_set& sources_set, bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { @@ -298,9 +317,9 @@ BackpropInitialState PrepareBackprop( continue; } int64 op_id = op_id_it->second; - auto op_it = op_tape.find(op_id); + auto op_it = op_tape->find(op_id); auto result_op_it = result.op_tape.find(op_id); - if (op_id == -1 || op_it == op_tape.end() || + if (op_id == -1 || op_it == op_tape->end() || result_op_it != result.op_tape.end()) { continue; } @@ -317,7 +336,9 @@ BackpropInitialState PrepareBackprop( } } } - op_tape.erase(op_it); + if (!persistent_tape) { + op_tape->erase(op_it); + } } for (auto& pair : result.tensor_usage_counts) { auto it = tensor_tape.find(pair.first); @@ -325,9 +346,15 @@ BackpropInitialState PrepareBackprop( result.op_missing_tensor[it->second] += 1; } } - // Call destructors for all unneeded gradient functions. - for (const auto& op_pair : op_tape) { - op_pair.second.backward_function_deleter(); + if (!persistent_tape) { + // Call destructors for all unneeded gradient functions and + // clear the op_tape. We can clear the tape because ownership of + // backward functions that will be used for gradient computation + // has been transfered to `result`. + for (const auto& op_pair : *op_tape) { + op_pair.second.backward_function_deleter(); + } + op_tape->clear(); } return result; } @@ -369,7 +396,8 @@ Status InitialGradients( auto op_it = op_tape.find(tensor_it->second); if (op_it == op_tape.end()) { return errors::Internal( - "Internal state of the gradient tape is invalid."); + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); } bool found = false; for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { @@ -383,7 +411,8 @@ Status InitialGradients( } if (!found) { return errors::Internal( - "Internal state of the gradient tape is invalid."); + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); } } else { // No record of the target tensor found on the tape, so no gradient @@ -415,17 +444,19 @@ Status GradientTape::ComputeGradient( std::unordered_set sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); BackpropInitialState state = PrepareBackprop( - target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set); + target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); std::unordered_map> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, tensor_tape_, state.op_tape, state.tensor_usage_counts, &gradients); - auto cleanup = [&state]() { - // Release all backprop functions - for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + auto cleanup = [this, &state]() { + if (!persistent_) { + // Release all backprop functions + for (const auto& pair : state.op_tape) { + pair.second.backward_function_deleter(); + } } }; if (!s.ok()) { @@ -484,6 +515,9 @@ Status GradientTape::ComputeGradient( std::vector in_gradients; Status s = 
vspace.CallBackwardFunction(trace.backward_function, out_gradients, &in_gradients); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } if (!s.ok()) { cleanup(); return s; diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 25f7ae785e6..0144f3b1e59 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -798,13 +798,41 @@ class GradientTape(object): grad = g.gradient(y, [x])[0] assert grad.numpy() == 6.0 ``` + + By default, the resources held by a GradientTape are released as soon as + GradientTape.gradient() method is called. However, if one need to compute + multiple gradients over the same computation, she can create a persistent + GradientTape. Persistent tapes allow multiple calls to the gradient() method + and release resources when the tape object is destructed. + + Example usage: + + ```python + with tfe.GradientTape(persistent=True) as g: + x = tf.constant(3.0) + g.watch(x) + y = x * x + z = y * y + dz_dx = g.gradient(z, [x])[0] + assert dz_dx.numpy() == 108.0 # 4*x^3 at x = 3 + dy_dx = g.gradient(y, [x])[0] + assert dy_dx.numpy() == 6.0 + del g # Drop the reference to the tape """ - def __init__(self): + def __init__(self, persistent=False): + """Creates a new GradientTape. + + Args: + persistent: Boolean controlling whether a persistent gradient tape + is created. Must be True or False. + + """ self._tape = None + self._persistent = persistent def __enter__(self): - tape.push_new_tape() + tape.push_new_tape(persistent=self._persistent) return self def __exit__(self, typ, value, traceback): @@ -838,12 +866,14 @@ class GradientTape(object): than once. """ if self._tape is None: - raise RuntimeError("GradientTape.gradient can only be called once, and " + raise RuntimeError("GradientTape.gradient can only be called once " + "on non-persistent tapes, and " "only when the context manager has exited.") sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable) else x for x in sources] grad = imperative_grad.imperative_grad( _default_vspace, self._tape, [target], sources) - self._tape = None + if not self._persistent: + self._tape = None return grad diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index e18ebba7857..9816dd022eb 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -314,6 +314,37 @@ class BackpropTest(test.TestCase): RuntimeError, 'GradientTape.gradient can only be called once'): g.gradient(y, [x]) + def testPersistentTape(self): + with backprop.GradientTape(persistent=True) as g: + x = constant_op.constant(3.0) + g.watch(x) + y = x * x + z = y * y + dz_dx = g.gradient(z, [x])[0] + self.assertEqual(dz_dx.numpy(), 4*3*3*3) + dy_dx = g.gradient(y, [x])[0] + self.assertEqual(dy_dx.numpy(), 2*3) + del g + + def testPersistentNestedTape(self): + with backprop.GradientTape(persistent=True) as g: + x = constant_op.constant(3.0) + g.watch(x) + y = x * x + with backprop.GradientTape(persistent=True) as gg: + gg.watch(y) + z = 2 * y + for _ in range(2): + inner_grad = gg.gradient(z, [y])[0] + self.assertEqual(inner_grad.numpy(), 2.0) + y += inner_grad + del gg + grad = g.gradient(y, [x])[0] + self.assertEqual(grad.numpy(), 6.0) + grad = g.gradient(z, [x])[0] + self.assertEqual(grad.numpy(), 12.0) + del g + def testGradientTapeVariable(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') with backprop.GradientTape() as g: diff --git 
a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index f96245f7a53..a33b17ada6f 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -88,7 +88,8 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); // Pushes a new tape into the thread-local stack. -void TFE_Py_TapeStackPushNew(); +// `persistent` must be a PyBool_Type, i.e either Py_True or Py_False +void TFE_Py_TapeStackPushNew(PyObject* persistent); // Pops the tape from the top of the stack and returns it. PyObject* TFE_Py_TapeStackPop(); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 0a0749fd4b9..ce823cb5679 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -469,7 +469,8 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) { class GradientTape : public tensorflow::eager::GradientTape { public: - GradientTape() {} + explicit GradientTape(bool persistent) + : tensorflow::eager::GradientTape(persistent) {} void WatchVariable(PyObject* v) { watched_variables_.insert(v); @@ -557,11 +558,11 @@ std::vector* GetTapeStack() { } #endif -void TFE_Py_TapeStackPushNew() { +void TFE_Py_TapeStackPushNew(PyObject* persistent) { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); - tape->tape = new GradientTape(); + tape->tape = new GradientTape(persistent == Py_True); GetTapeStack()->push_back(tape); } @@ -704,6 +705,7 @@ std::vector MakeTensorIDList(PyObject* tensors) { PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i); list.push_back(FastTensorId(tensor)); if (PyErr_Occurred()) { + Py_DECREF(seq); return list; } } @@ -889,7 +891,6 @@ class PyVSpace : public tensorflow::eager::VSpace { PyObject* py_result = PyEval_CallObject( reinterpret_cast(backward_function), grads); Py_DECREF(grads); - Py_DECREF(backward_function); if (py_result == nullptr) { return tensorflow::errors::Internal("gradient function threw exceptions"); } @@ -917,6 +918,10 @@ class PyVSpace : public tensorflow::eager::VSpace { return tensorflow::Status::OK(); } + void ReleaseBackwardFunction(PyObject* backward_function) const final { + Py_DECREF(backward_function); + } + void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } private: diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 440c84b7ea9..14b5238f740 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -33,9 +33,9 @@ class Tape(object): return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) -def push_new_tape(): +def push_new_tape(persistent=False): """Pushes a new tape onto the tape stack.""" - pywrap_tensorflow.TFE_Py_TapeStackPushNew() + pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent) def watch(tensor): From f25abbfb25441bec198ca7517485fbab63f07be1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Nov 2017 17:40:42 -0800 Subject: [PATCH 08/15] Minor cleanup: remove unnecessary GetCudaContext. Note that there is no protection against a caller of CUDAExecutor::Launch from accidentally passing a Stream associated with the wrong CUDAExecutor. This is no different from any other CUDAExecutor methods that take a Stream argument, where we similarly have no such protection. 
The main caller is Stream::ThenLaunch, which necessarily calls Launch on the correct corresponding CUDAExecutor. Other callers use a similar pattern. PiperOrigin-RevId: 176719918 --- .../stream_executor/cuda/cuda_gpu_executor.cc | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 6c522264e1d..64d14f29dfe 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -108,11 +108,6 @@ static CUdeviceptr AsCudaDevicePtr(DeviceMemoryBase *gpu_mem) { return AsCudaDevicePtr(*gpu_mem); } -static CudaContext* GetCudaContext(Stream *stream) { - return static_cast(stream->parent()->implementation()) - ->cuda_context(); -} - CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec) { CHECK(cuda_exec != nullptr); return cuda_exec->cuda_context(); @@ -380,11 +375,11 @@ bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims, void **kernel_params = const_cast(args.argument_addresses().data()); - if (!CUDADriver::LaunchKernel(GetCudaContext(stream), cufunc, block_dims.x, - block_dims.y, block_dims.z, thread_dims.x, - thread_dims.y, thread_dims.z, - args.number_of_shared_bytes(), custream, - kernel_params, nullptr /* = extra */)) { + if (!CUDADriver::LaunchKernel(context_, cufunc, block_dims.x, block_dims.y, + block_dims.z, thread_dims.x, thread_dims.y, + thread_dims.z, args.number_of_shared_bytes(), + custream, kernel_params, + nullptr /* = extra */)) { LOG(ERROR) << "failed to launch CUDA kernel with args: " << args.number_of_arguments() << "; thread dim: " << thread_dims.ToString() From 4e5534d3d35a72b87902212e2847ca2871cc7b75 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 22 Nov 2017 17:53:40 -0800 Subject: [PATCH 09/15] Prevented a couple of memory leaks in the code generated by swig PiperOrigin-RevId: 176720721 --- tensorflow/python/BUILD | 5 +- tensorflow/python/grappler/cluster.i | 87 ++++++++++++++-------- tensorflow/python/grappler/cluster.py | 4 +- tensorflow/python/grappler/cost_analyzer.i | 11 +-- tensorflow/python/grappler/item.i | 44 ++++++++--- tensorflow/python/grappler/tf_optimizer.i | 14 +--- tensorflow/python/grappler/tf_optimizer.py | 6 +- 7 files changed, 110 insertions(+), 61 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 54c43c1337a..9d3974b98e5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4350,7 +4350,10 @@ py_library( ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [":pywrap_tensorflow_internal"], + deps = [ + ":pywrap_tensorflow_internal", + ":tf_cluster", + ], ) py_test( diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 5a7cdf26f8e..18fda345e6d 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -14,6 +14,14 @@ limitations under the License. ==============================================================================*/ %include "tensorflow/python/platform/base.i" +%include +%include "item.i" + +// Wrap the cluster into an object that swig can manipulate. This ensures it will call the object +// destructor upon garbage collection instead of leaking memory. 
+struct GCluster { + std::shared_ptr cluster_; +}; %{ #include "tensorflow/core/protobuf/device_properties.pb.h" @@ -72,6 +80,7 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) { } %{ +#include #include #include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/clusters/single_machine.h" @@ -82,39 +91,56 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) { #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/protobuf/device_properties.pb.h" -static tensorflow::grappler::Cluster* TF_NewCluster( - bool allow_soft_placement, - bool disable_detailed_stats, TF_Status* out_status) { - int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); - int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();; +// Provide the implementation of the GCluster struct here. +struct GCluster { + GCluster() {} + GCluster(tensorflow::grappler::Cluster* cluster) : cluster_(cluster) {} + + tensorflow::grappler::Cluster* operator->() const { + return cluster_.get(); + } + tensorflow::grappler::Cluster* get() const { + return cluster_.get(); + } + bool is_none() const { + return cluster_.get() == nullptr; + } + + std::shared_ptr cluster_; +}; + + +static GCluster TF_NewCluster(bool allow_soft_placement, + bool disable_detailed_stats, TF_Status* out_status) { + int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); + int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); int timeout_s = 60 * 10; - tensorflow::grappler::Cluster* cluster = + tensorflow::grappler::Cluster* cluster_ = new tensorflow::grappler::SingleMachine( timeout_s, num_cpu_cores, num_gpus); - cluster->DisableDetailedStats(disable_detailed_stats); - cluster->AllowSoftPlacement(allow_soft_placement); - tensorflow::Status status = cluster->Provision(); + cluster_->DisableDetailedStats(disable_detailed_stats); + cluster_->AllowSoftPlacement(allow_soft_placement); + tensorflow::Status status = cluster_->Provision(); tensorflow::Set_TF_Status_from_Status(out_status, status); - return cluster; + return GCluster(cluster_); } -static tensorflow::grappler::Cluster* TF_NewVirtualCluster( +static GCluster TF_NewVirtualCluster( const std::vector& named_devices, TF_Status* out_status) { std::unordered_map devices; for (const auto& named_device : named_devices) { devices[named_device.name()]= named_device.properties(); } - tensorflow::grappler::Cluster* cluster = + tensorflow::grappler::Cluster*cluster_ = new tensorflow::grappler::VirtualCluster(devices); - tensorflow::Status status = cluster->Provision(); + tensorflow::Status status = cluster_->Provision(); tensorflow::Set_TF_Status_from_Status(out_status, status); - return cluster; + return GCluster(cluster_); } -static void TF_DeleteCluster(tensorflow::grappler::Cluster* cluster) { +static void TF_ShutdownCluster(GCluster cluster) { cluster->Shutdown(); - delete cluster; } tensorflow::Status _GetOpPerformanceDataAndRunTime( @@ -136,7 +162,7 @@ tensorflow::Status _GetOpPerformanceDataAndRunTime( return tensorflow::Status::OK(); } -static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) { +static PyObject* TF_ListDevices(GCluster cluster) { const std::unordered_map& devices = cluster->GetDevices(); PyGILState_STATE gstate = PyGILState_Ensure(); PyObject* result = PyList_New(devices.size()); @@ -156,13 +182,13 @@ static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) { } static PyObject* TF_MeasureCosts( - const tensorflow::grappler::GrapplerItem* item, - 
tensorflow::grappler::Cluster* cluster, + GItem item, + GCluster cluster, bool generate_timeline, TF_Status* out_status) { tensorflow::OpPerformanceList op_performance_data; tensorflow::StepStats step_stats; - tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster, 10, 0); + tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), 10, 0); tensorflow::grappler::Costs costs; tensorflow::Status status = _GetOpPerformanceDataAndRunTime( @@ -223,10 +249,10 @@ static PyObject* TF_MeasureCosts( static PyObject* TF_DeterminePeakMemoryUsage( - const tensorflow::grappler::GrapplerItem* item, - tensorflow::grappler::Cluster* cluster, + GItem item, + GCluster cluster, TF_Status* out_status) { - if (!item || !cluster) { + if (item.is_none() || cluster.is_none()) { tensorflow::Status status(tensorflow::error::Code::INTERNAL, "You need both a cluster and an item to determine peak memory usage"); tensorflow::Set_TF_Status_from_Status(out_status, status); @@ -236,7 +262,7 @@ static PyObject* TF_DeterminePeakMemoryUsage( tensorflow::Status status; if (cluster->DetailedStatsEnabled()) { - status = memory.InferDynamically(cluster); + status = memory.InferDynamically(cluster.get()); } else { status = memory.InferStatically(cluster->GetDevices()); } @@ -274,18 +300,17 @@ static PyObject* TF_DeterminePeakMemoryUsage( %} // Wrap these functions. - -static tensorflow::grappler::Cluster* TF_NewCluster( +static GCluster TF_NewCluster( bool allow_soft_placement, bool disable_detailed_stats, TF_Status* out_status); -static tensorflow::grappler::Cluster* TF_NewVirtualCluster( +static GCluster TF_NewVirtualCluster( const std::vector& named_devices, TF_Status* out_status); -static void TF_DeleteCluster(tensorflow::grappler::Cluster* cluster); -static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster); +static void TF_ShutdownCluster(GCluster cluster); +static PyObject* TF_ListDevices(GCluster cluster); static PyObject* TF_MeasureCosts( - const tensorflow::grappler::GrapplerItem* item, tensorflow::grappler::Cluster* cluster, + GItem item, GCluster cluster, bool generate_timeline, TF_Status* out_status); static PyObject* TF_DeterminePeakMemoryUsage( - const tensorflow::grappler::GrapplerItem* item, tensorflow::grappler::Cluster* cluster, + GItem item, GCluster cluster, TF_Status* out_status); diff --git a/tensorflow/python/grappler/cluster.py b/tensorflow/python/grappler/cluster.py index 496f5255b95..cf795fddb71 100644 --- a/tensorflow/python/grappler/cluster.py +++ b/tensorflow/python/grappler/cluster.py @@ -46,6 +46,7 @@ class Cluster(object): the local machine. """ self._tf_cluster = None + self._generate_timeline = not disable_timeline with errors.raise_exception_on_not_ok_status() as status: if devices is None: self._tf_cluster = tf_cluster.TF_NewCluster( @@ -54,11 +55,10 @@ class Cluster(object): devices_serialized = [device.SerializeToString() for device in devices] self._tf_cluster = tf_cluster.TF_NewVirtualCluster( devices_serialized, status) - self._generate_timeline = not disable_timeline def __del__(self): if self._tf_cluster is not None: - tf_cluster.TF_DeleteCluster(self._tf_cluster) + tf_cluster.TF_ShutdownCluster(self._tf_cluster) @property def tf_cluster(self): diff --git a/tensorflow/python/grappler/cost_analyzer.i b/tensorflow/python/grappler/cost_analyzer.i index 0318ff762c9..4c0953435ba 100644 --- a/tensorflow/python/grappler/cost_analyzer.i +++ b/tensorflow/python/grappler/cost_analyzer.i @@ -15,6 +15,7 @@ limitations under the License. 
%include "tensorflow/python/lib/core/strings.i" %include "tensorflow/python/platform/base.i" +%include "cluster.i" %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { char* c_string; @@ -42,8 +43,8 @@ limitations under the License. %} %{ -string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool -per_node_report, tensorflow::grappler::Cluster* cluster) { +string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report, + GCluster cluster) { tensorflow::grappler::ItemConfig cfg; cfg.apply_optimizations = false; std::unique_ptr item = @@ -53,7 +54,7 @@ per_node_report, tensorflow::grappler::Cluster* cluster) { } string suffix; - tensorflow::grappler::CostAnalyzer analyzer(*item, cluster, suffix); + tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix); std::stringstream os; analyzer.GenerateReport(os, per_node_report); @@ -62,5 +63,5 @@ per_node_report, tensorflow::grappler::Cluster* cluster) { %} -string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool - per_node_report, tensorflow::grappler::Cluster* cluster); +string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report, + GCluster cluster); diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index 2fa502b81da..7dd79f7c82c 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +%include %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { char* c_string; Py_ssize_t py_size; @@ -30,7 +31,12 @@ limitations under the License. $1 = &temp; } -%newobject TF_NewItem; +// Wrap the item into an object that swig can manipulate. This ensures it will call the object +// destructor upon garbage collection instead of leaking memory. +struct GItem { + std::shared_ptr item_; +}; + %{ #include @@ -42,8 +48,26 @@ limitations under the License. #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" -static tensorflow::grappler::GrapplerItem* TF_NewItem( +// Provide the implementation fo the GItem struct here. 
+struct GItem { + GItem() {} + GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {} + + tensorflow::grappler::GrapplerItem* operator->() const { + return item_.get(); + } + const tensorflow::grappler::GrapplerItem& operator*() const { + return *item_.get(); + } + bool is_none() const { + return item_.get() == nullptr; + } + std::shared_ptr item_; +}; + +static GItem TF_NewItem( const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation, bool ignore_user_placement, TF_Status* out_status) { if (meta_graph.collection_def().count("train_op") == 0) { @@ -65,11 +89,11 @@ static tensorflow::grappler::GrapplerItem* TF_NewItem( return nullptr; } tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK()); - return item.release(); + return GItem(item.release()); } -static std::vector TF_IdentifyImportantOps(const tensorflow::grappler::GrapplerItem* item) { - if (!item) { +static std::vector TF_IdentifyImportantOps(GItem item) { + if (item.is_none()) { return {}; } @@ -91,8 +115,8 @@ static std::vector TF_IdentifyImportantOps(const tensorflow::grappler::G return ops; } -static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* item) { - if (!item) { +static PyObject* TF_GetOpProperties(GItem item) { + if (item.is_none()) { Py_RETURN_NONE; } tensorflow::grappler::GraphProperties properties(*item); @@ -125,8 +149,8 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it // Wrap these functions. -static tensorflow::grappler::GrapplerItem* TF_NewItem( +static GItem TF_NewItem( const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation, bool ignore_user_placement, TF_Status* out_status); -static std::vector TF_IdentifyImportantOps(const tensorflow::grappler::GrapplerItem* item); -static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* item); +static std::vector TF_IdentifyImportantOps(GItem item); +static PyObject* TF_GetOpProperties(GItem item); diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i index 3965c65bb90..f0dd4483a63 100644 --- a/tensorflow/python/grappler/tf_optimizer.i +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -15,6 +15,7 @@ limitations under the License. 
%include "tensorflow/python/platform/base.i" +%include "cluster.i" %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { char* c_string; @@ -92,7 +93,7 @@ void DetectDevices(std::unordered_map* dev } PyObject* TF_OptimizeGraph( - tensorflow::grappler::Cluster* cluster, + GCluster cluster, const tensorflow::RewriterConfig& rewriter_config, const tensorflow::MetaGraphDef& metagraph, bool verbose, const string& graph_id, TF_Status* out_status) { @@ -102,17 +103,10 @@ PyObject* TF_OptimizeGraph( std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); - std::unique_ptr virtual_cluster; - if (cluster == nullptr) { - std::unordered_map device_map; - DetectDevices(&device_map); - virtual_cluster.reset(new tensorflow::grappler::VirtualCluster(device_map)); - cluster = virtual_cluster.get(); - } tensorflow::DeviceBase* cpu_device = nullptr; tensorflow::GraphDef out_graph; tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config); - tensorflow::Status status = optimizer.Optimize(cluster, *grappler_item, &out_graph); + tensorflow::Status status = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph); if (verbose) { optimizer.PrintResult(); } @@ -127,7 +121,7 @@ PyObject* TF_OptimizeGraph( // Wrap this function PyObject* TF_OptimizeGraph( - tensorflow::grappler::Cluster* cluster, + GCluster cluster, const tensorflow::RewriterConfig& rewriter_config, const tensorflow::MetaGraphDef& metagraph, bool verbose, const string& graph_id, TF_Status* out_status); diff --git a/tensorflow/python/grappler/tf_optimizer.py b/tensorflow/python/grappler/tf_optimizer.py index d430dd9e2f8..a73a4a98fc5 100644 --- a/tensorflow/python/grappler/tf_optimizer.py +++ b/tensorflow/python/grappler/tf_optimizer.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.core.framework import graph_pb2 from tensorflow.python import pywrap_tensorflow as tf_opt from tensorflow.python.framework import errors +from tensorflow.python.grappler import cluster as gcluster def OptimizeGraph(rewriter_config, @@ -30,8 +31,9 @@ def OptimizeGraph(rewriter_config, cluster=None): """Optimize the provided metagraph.""" with errors.raise_exception_on_not_ok_status() as status: - ret_from_swig = tf_opt.TF_OptimizeGraph(None if cluster is None else - cluster.tf_cluster, + if cluster is None: + cluster = gcluster.Cluster() + ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster, rewriter_config.SerializeToString(), metagraph.SerializeToString(), verbose, graph_id, status) From b76620aed0c02d01a823df57e06a67bc4c1424c0 Mon Sep 17 00:00:00 2001 From: James Keeling Date: Wed, 22 Nov 2017 17:58:34 -0800 Subject: [PATCH 10/15] Default to previously specified variables when minimizing with KfacOptimizer If no variables are specified to minimize or compute_gradients, the default was previously to use all trainable variables. However, KfacOptimizer has a list of variables it is able to train, so we should use that instead. 
PiperOrigin-RevId: 176720954 --- .../contrib/kfac/python/ops/optimizer.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 98f8e7b2308..ecf7f3e4e5a 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -151,16 +151,24 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): return self._fisher_est.damping def minimize(self, *args, **kwargs): - - if "var_list" not in kwargs: - kwargs["var_list"] = tf_variables.trainable_variables() - + kwargs["var_list"] = kwargs.get("var_list") or self.variables if set(kwargs["var_list"]) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") - return super(KfacOptimizer, self).minimize(*args, **kwargs) + def compute_gradients(self, *args, **kwargs): + # args[1] could be our var_list + if len(args) > 1: + var_list = args[1] + else: + kwargs["var_list"] = kwargs.get("var_list") or self.variables + var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + return super(KfacOptimizer, self).compute_gradients(*args, **kwargs) + def apply_gradients(self, grads_and_vars, *args, **kwargs): """Applies gradients to variables. From ebd26397ab708242d22880f789b168eb16897691 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Wed, 22 Nov 2017 18:22:07 -0800 Subject: [PATCH 11/15] Do not convert layout for FusedBatchNormGrad if is_training is false (freeze mode), since NCHW is not supported on GPU in this case. PiperOrigin-RevId: 176722850 --- .../grappler/optimizers/layout_optimizer.cc | 14 +++++ .../optimizers/layout_optimizer_test.cc | 58 +++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 31c3ba68637..d25d9d99c51 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -714,10 +714,24 @@ class FusedBatchNormGradProcessor : public NodeProcessor { : NodeProcessor(opt_cxt) {} protected: + bool ShouldProcess() const override { + return NodeProcessor::ShouldProcess() && IsTraining(); + } + std::vector GetInputPos() const override { std::vector input_pos = {0, 1}; return input_pos; } + + private: + bool IsTraining() const { + if (node_->attr().find("is_training") != node_->attr().end()) { + if (node_->attr().at("is_training").b()) { + return true; + } + } + return false; + } }; class MaxPoolGradProcessor : public NodeProcessor { diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index d4ab42ad60f..20a971629c8 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -114,6 +114,36 @@ class LayoutOptimizerTest : public ::testing::Test { return tensor; } + Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) { + int batch_size = 16; + int input_height = 8; + int input_width = 8; + int input_channels = 3; + TensorShape shape({batch_size, input_height, input_width, input_channels}); + Tensor data(DT_FLOAT, shape); + test::FillIota(&data, 1.0f); + Output x = ops::Const(s->WithOpName("Input"), 
Input::Initializer(data)); + Output y_backprop = + ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data)); + + TensorShape shape_vector({input_channels}); + Tensor data_vector(DT_FLOAT, shape_vector); + test::FillIota(&data_vector, 2.0f); + Output scale = + ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector)); + Output reserve1 = + ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector)); + Output reserve2 = + ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector)); + + ops::FusedBatchNormGrad::Attrs attrs; + attrs.is_training_ = is_training; + auto output = + ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop, + x, scale, reserve1, reserve2, attrs); + return output.x_backprop; + } + std::unique_ptr virtual_cluster_; }; @@ -341,6 +371,34 @@ TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) { EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC"); } +TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x_backprop = SimpleFusedBatchNormGrad(&s, true); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto conv_node = node_map.GetNode("FusedBatchNormGrad"); + EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW"); +} + +TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x_backprop = SimpleFusedBatchNormGrad(&s, false); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto conv_node = node_map.GetNode("FusedBatchNormGrad"); + EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC"); +} + } // namespace } // namespace grappler } // namespace tensorflow From 622a6ec6dc79c458aac03dafffe0f0fef48e9c01 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 22 Nov 2017 18:30:37 -0800 Subject: [PATCH 12/15] Replace assertAlmostEqual with assertAllClose in boosted_trees losses_test.py Calling assertAlmostEqual() with the places kwarg on numpy.ndarray leads to calling __round__ on numpy.ndarray, which is no consistently defined for all relevant platforms and numpy versions. PiperOrigin-RevId: 176723366 --- .../boosted_trees/python/utils/losses_test.py | 35 +++++++------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py index dde16426863..ccb8509c034 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as np from tensorflow.contrib.boosted_trees.python.utils import losses @@ -60,35 +58,27 @@ class LossesTest(test_util.TensorFlowTestCase): neg_loss = loss_for_negatives.eval() # For positive labels, points <= 0.3 get max loss of e. # For negative labels, these points have minimum loss of 1/e. 
- for i in range(2): - self.assertAlmostEqual(math.exp(1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(-1), neg_loss[i], places=4) + self.assertAllClose(np.exp(np.ones([2, 1])), pos_loss[:2], atol=1e-4) + self.assertAllClose(np.exp(-np.ones([2, 1])), neg_loss[:2], atol=1e-4) # For positive lables, p oints with predictions 0.7 and larger get minimum # loss value of 1/e. For negative labels, these points are wrongly # classified and get loss e. - for i in range(6, 10): - self.assertAlmostEqual(math.exp(-1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(1), neg_loss[i], places=4) + self.assertAllClose(np.exp(-np.ones([4, 1])), pos_loss[6:10], atol=1e-4) + self.assertAllClose(np.exp(np.ones([4, 1])), neg_loss[6:10], atol=1e-4) # Points in between 0.5-eps, 0..5+eps get loss exp(-label_m*y), where # y = 1/eps *x -1/(2eps), where x is the probability and label_m is either # 1 or -1 (for label of 0). - for i in range(2, 6): - self.assertAlmostEqual( - math.exp(-1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - pos_loss[i], - places=4) - self.assertAlmostEqual( - math.exp(1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - neg_loss[i], - places=4) + self.assertAllClose( + np.exp(-(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps)), + pos_loss[2:6], atol=1e-4) + self.assertAllClose( + np.exp(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps), + neg_loss[2:6], atol=1e-4) def test_per_example_squared_loss(self): - def _squared_loss(p, y): - return np.mean(1.0 * (p - y) * (p - y)) - labels = np.array([[0.123], [224.2], [-3], [2], [.3]], dtype=np.float32) weights = array_ops.ones([5, 1], dtypes.float32) predictions = np.array( @@ -99,9 +89,8 @@ class LossesTest(test_util.TensorFlowTestCase): predictions) loss = loss_tensor.eval() - for i in range(5): - self.assertAlmostEqual( - _squared_loss(labels[i], predictions[i]), loss[i], places=4) + self.assertAllClose( + np.square(labels[:5] - predictions[:5]), loss[:5], atol=1e-4) if __name__ == "__main__": From 34a69568752ef8badbe6aab5d1f568821c19e19c Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Wed, 22 Nov 2017 19:13:54 -0800 Subject: [PATCH 13/15] Fix flaky test. PiperOrigin-RevId: 176725659 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 995ce6d6546..c017cd9c777 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -331,7 +331,7 @@ py_test( py_test( name = "reader_dataset_ops_test", - size = "small", + size = "medium", srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", deps = [ From 059e35acc985e99e522ffe89df12cd357871309b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Nov 2017 19:28:47 -0800 Subject: [PATCH 14/15] Minor cleanup - replace users()[0] with users->front(). 
PiperOrigin-RevId: 176726299 --- tensorflow/compiler/xla/service/hlo_verifier.cc | 4 ++-- tensorflow/compiler/xla/service/while_loop_simplifier.cc | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index f2a739c1e24..15188c4057e 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -283,7 +283,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleSend(HloInstruction* send) override { TF_RET_CHECK(send->users().size() == 1); - const HloInstruction* send_done = send->users()[0]; + const HloInstruction* send_done = send->users().front(); TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); return CheckShape( @@ -301,7 +301,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleRecv(HloInstruction* recv) override { TF_RET_CHECK(recv->users().size() == 1); - const HloInstruction* recv_done = recv->users()[0]; + const HloInstruction* recv_done = recv->users().front(); TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); return CheckShape(recv, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 8f335be7945..b38ee907d70 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -342,7 +342,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // // Careful: HloInstruction::operand_index returns the first index the // operand appears in, but it may appear more than once! - if (user->user_count() == 1 && user->users()[0] == while_body_root && + if (user->user_count() == 1 && user->users().front() == while_body_root && while_body_root->operand_index(user) == user->tuple_index() && std::count(while_body_root->operands().begin(), while_body_root->operands().end(), user) == 1) { @@ -444,7 +444,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // This is a GTE of an index that we've removed. Remove it from the // cloned computation. CHECK(user->user_count() == 0 || - user->user_count() == 1 && user->users()[0] == while_body_root) + user->user_count() == 1 && + user->users().front() == while_body_root) << "Instruction " << user->ToStringNoMetadata() << " should be unused (except by root of while body), but has " "users: {" From d73e8b36d1332723f5819d07f8c44e88c49c7cec Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Wed, 22 Nov 2017 21:20:41 -0800 Subject: [PATCH 15/15] Make PrefetchDataset saveable. 
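Here "saveable" means the iterator can serialize its in-flight state -- the parent iterator's position plus whatever is sitting in the prefetch buffer -- and later rebuild exactly that state. The toy, TensorFlow-free sketch below illustrates that contract; every name in it is invented for illustration, and it deliberately ignores the background prefetch thread used by the real kernel.

class PrefetchingIterator:
  """Toy model of the state a saveable prefetch iterator must capture."""

  def __init__(self, data, buffer_size):
    self._data = list(data)   # stands in for the upstream dataset
    self._cursor = 0          # upstream progress (the parent iterator)
    self._buffer = []         # prefetched but not yet delivered elements
    self._buffer_size = buffer_size

  def _prefetch(self):
    while (len(self._buffer) < self._buffer_size and
           self._cursor < len(self._data)):
      self._buffer.append(self._data[self._cursor])
      self._cursor += 1

  def get_next(self):
    self._prefetch()
    if not self._buffer:
      raise StopIteration
    return self._buffer.pop(0)

  def save(self):
    # Analogue of SaveInternal: both the upstream position and the buffered
    # elements must be written, otherwise a restored run would drop or
    # duplicate the elements that were in flight at checkpoint time.
    return {"cursor": self._cursor, "buffer": list(self._buffer)}

  def restore(self, state):
    # Analogue of RestoreInternal: rebuild the buffer before serving
    # get_next() again.
    self._cursor = state["cursor"]
    self._buffer = list(state["buffer"])

This is also why SaveInternal in the kernel change below acquires both parent_mu_ and mu_: the upstream position and the buffer contents have to be captured as one consistent snapshot with respect to the prefetch thread.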
PiperOrigin-RevId: 176732156 --- .../contrib/data/python/kernel_tests/BUILD | 12 ++ .../kernel_tests/prefetch_dataset_op_test.py | 39 +++++ tensorflow/core/kernels/BUILD | 1 + .../core/kernels/prefetch_dataset_op.cc | 141 +++++++++++++++++- 4 files changed, 186 insertions(+), 7 deletions(-) create mode 100644 tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index c017cd9c777..3280f1fc356 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -303,6 +303,18 @@ py_test( ], ) +py_test( + name = "prefetch_dataset_op_test", + size = "small", + srcs = ["prefetch_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "range_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py new file mode 100644 index 00000000000..3d120a3071e --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class PrefetchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, seed): + return dataset_ops.Dataset.range(100).prefetch(10).shuffle( + buffer_size=10, seed=seed, reshuffle_each_iteration=False) + + def testCore(self): + num_outputs = 100 + self.run_core_tests(lambda: self.build_dataset(10), + lambda: self.build_dataset(20), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index b4a5a3c7966..b86739eea71 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6054,6 +6054,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/core/kernels/prefetch_dataset_op.cc b/tensorflow/core/kernels/prefetch_dataset_op.cc index 80592aa353a..1a6b7e078eb 100644 --- a/tensorflow/core/kernels/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/prefetch_dataset_op.cc @@ -14,9 +14,10 @@ limitations under the License. 
==============================================================================*/ #include -#include "tensorflow/core/kernels/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/dataset.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" namespace tensorflow { @@ -37,14 +38,14 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); - *output = new Dataset(input, buffer_size); + *output = new Dataset(ctx, input, buffer_size); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(const DatasetBase* input, int64 buffer_size) - : input_(input), buffer_size_(buffer_size) { + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size) + : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) { input_->Ref(); } @@ -65,6 +66,18 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "PrefetchDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + Node* buffer_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node, buffer_size}, output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -119,7 +132,10 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { // Wake the prefetch thread, in case it has been waiting // for space in the buffer. - cond_var_.notify_one(); + // Also wake up threads from other calls to GetNext. + // TODO(mrry): Consider using different condition variables + // for GetNext and Prefetch. + cond_var_.notify_all(); return s; } else if (prefetch_thread_finished_) { *end_of_sequence = true; @@ -128,6 +144,69 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { } } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + // Acquire both locks to ensure that the prefetch thread and + // all GetNext threads are blocked. 
+ mutex_lock parent_l(parent_mu_); + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("buffer_size"), buffer_.size())); + for (size_t i = 0; i < buffer_.size(); i++) { + auto& buffer_element = buffer_[i]; + TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status)); + if (buffer_element.status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("buffer[", i, "].size")), + buffer_element.value.size())); + for (size_t j = 0; j < buffer_element.value.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + strings::StrCat("buffer[", i, "][", j, "]"), + buffer_element.value[j])); + } + } + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock parent_l(parent_mu_); + mutex_lock l(mu_); + buffer_.clear(); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + size_t buffer_size; + { + int64 temp; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("buffer_size"), &temp)); + buffer_size = static_cast(temp); + } + for (size_t i = 0; i < buffer_size; i++) { + buffer_.emplace_back(); + auto& buffer_element = buffer_.back(); + TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status)); + if (buffer_element.status.ok()) { + size_t value_size; + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("buffer[", i, "].size")), &temp)); + value_size = static_cast(temp); + } + buffer_element.value.reserve(value_size); + for (size_t j = 0; j < value_size; j++) { + buffer_element.value.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + strings::StrCat("buffer[", i, "][", j, "]"), + &buffer_element.value.back())); + } + } + } + return Status::OK(); + } + private: // A buffer element comprises a status and (if that status is // OK) a vector of tensors, representing an element of the input dataset. @@ -171,6 +250,12 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { } // 2. Read the next element. + // Acquire the parent lock since we will be reading an element + // from the input iterator. Note that we do not wish to release + // this lock till we have added the fetched element to the + // `buffer_` else there will be local state that may be missed + // by SaveInternal. 
+ mutex_lock parent_l(parent_mu_); bool end_of_sequence; BufferElement buffer_element; buffer_element.status = input_impl_->GetNext( @@ -191,8 +276,50 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { } } + Status WriteStatus(IteratorStateWriter* writer, size_t index, + const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + CodeKey(index), static_cast(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), + status.error_message())); + } + return Status::OK(); + } + + Status ReadStatus(IteratorStateReader* reader, size_t index, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); + } + return Status::OK(); + } + + string CodeKey(size_t index) { + return full_name(strings::StrCat("status[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name(strings::StrCat("status[", index, "].error_message")); + } + + // This mutex is used to ensure exclusivity between multiple threads + // reading/writing this iterator's local state. mutex mu_; - const std::unique_ptr input_impl_; + // This mutex is used to ensure exclusivity between multiple threads + // accessing the parent iterator. We keep this separate from `mu_` to + // allow prefetching to run in parallel with GetNext calls. + mutex parent_mu_ ACQUIRED_BEFORE(mu_); + const std::unique_ptr input_impl_ GUARDED_BY(parent_mu_); condition_variable cond_var_; std::deque buffer_ GUARDED_BY(mu_); std::unique_ptr prefetch_thread_ GUARDED_BY(mu_);