Merge pull request #14922 from sb2nov/branch_176732156

Branch 176732156
Sourabh Bajaj 2017-11-27 16:36:36 -08:00 committed by GitHub
commit 55055a23c8
31 changed files with 778 additions and 236 deletions


@ -106,6 +106,12 @@ class VSpace {
// Deletes the input tensor. // Deletes the input tensor.
virtual void DeleteGradient(Gradient* gradient) const = 0; virtual void DeleteGradient(Gradient* gradient) const = 0;
// Lets this VSpace know that it can release resources held by the
// `backward_function`. It will not be called again.
// `backward_function` must not be null.
virtual void ReleaseBackwardFunction(
BackwardFunction* backward_function) const = 0;
}; };
// Traces the execution of operations, doing eager garbage collection, and // Traces the execution of operations, doing eager garbage collection, and
@ -113,7 +119,11 @@ class VSpace {
template <typename Gradient, typename BackwardFunction> template <typename Gradient, typename BackwardFunction>
class GradientTape { class GradientTape {
public: public:
GradientTape() {} // If `persistent` is true, GradientTape will not eagerly delete backward
// functions (and hence the tensors they keep alive). Instead, everything
// is deleted in ~GradientTape. Persistent GradientTapes are useful when
// users want to compute multiple gradients over the same tape.
GradientTape(bool persistent) : persistent_(persistent) {}
~GradientTape() { ~GradientTape() {
for (const auto& pair : op_tape_) { for (const auto& pair : op_tape_) {
pair.second.backward_function_deleter(); pair.second.backward_function_deleter();
@ -150,6 +160,10 @@ class GradientTape {
// Map from tensor id to number of remaining usages (i.e. how many entries in // Map from tensor id to number of remaining usages (i.e. how many entries in
// the tape refer to it); to aid in tape garbage collection. // the tape refer to it); to aid in tape garbage collection.
std::unordered_map<int64, int64> tensor_usage_; std::unordered_map<int64, int64> tensor_usage_;
// If false, all activations are deleted in the first call to ComputeGradient.
// Else, they are deleted only when this tape is destructed.
bool persistent_;
}; };
// Template instantiations here // Template instantiations here
@ -279,11 +293,16 @@ struct BackpropInitialState {
std::unordered_map<int64, int64> op_missing_tensor; std::unordered_map<int64, int64> op_missing_tensor;
}; };
// If `persistent_tape` is true, `op_tape` is not changed and none of the
// backward functions are deleted.
// If `persistent_tape` is false, `op_tape` is cleared and backward functions
// not needed for gradient computation are deleted. Backward functions that
// are needed are copied and returned in BackpropInitialState.
template <typename BackwardFunction> template <typename BackwardFunction>
BackpropInitialState<BackwardFunction> PrepareBackprop( BackpropInitialState<BackwardFunction> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape, gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
OpTape<BackwardFunction> op_tape, OpTape<BackwardFunction>* op_tape,
const std::unordered_set<int64>& sources_set) { const std::unordered_set<int64>& sources_set, bool persistent_tape) {
std::vector<int64> tensor_stack; std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size()); tensor_stack.reserve(target.size());
for (auto t : target) { for (auto t : target) {
@ -298,9 +317,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
continue; continue;
} }
int64 op_id = op_id_it->second; int64 op_id = op_id_it->second;
auto op_it = op_tape.find(op_id); auto op_it = op_tape->find(op_id);
auto result_op_it = result.op_tape.find(op_id); auto result_op_it = result.op_tape.find(op_id);
if (op_id == -1 || op_it == op_tape.end() || if (op_id == -1 || op_it == op_tape->end() ||
result_op_it != result.op_tape.end()) { result_op_it != result.op_tape.end()) {
continue; continue;
} }
@ -317,7 +336,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
} }
} }
} }
op_tape.erase(op_it); if (!persistent_tape) {
op_tape->erase(op_it);
}
} }
for (auto& pair : result.tensor_usage_counts) { for (auto& pair : result.tensor_usage_counts) {
auto it = tensor_tape.find(pair.first); auto it = tensor_tape.find(pair.first);
@ -325,9 +346,15 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
result.op_missing_tensor[it->second] += 1; result.op_missing_tensor[it->second] += 1;
} }
} }
// Call destructors for all unneeded gradient functions. if (!persistent_tape) {
for (const auto& op_pair : op_tape) { // Call destructors for all unneeded gradient functions and
op_pair.second.backward_function_deleter(); // clear the op_tape. We can clear the tape because ownership of
// backward functions that will be used for gradient computation
// has been transferred to `result`.
for (const auto& op_pair : *op_tape) {
op_pair.second.backward_function_deleter();
}
op_tape->clear();
} }
return result; return result;
} }
@ -369,7 +396,8 @@ Status InitialGradients(
auto op_it = op_tape.find(tensor_it->second); auto op_it = op_tape.find(tensor_it->second);
if (op_it == op_tape.end()) { if (op_it == op_tape.end()) {
return errors::Internal( return errors::Internal(
"Internal state of the gradient tape is invalid."); "Internal state of the gradient tape is invalid: "
"failed to find operation producing a tensor");
} }
bool found = false; bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
@ -383,7 +411,8 @@ Status InitialGradients(
} }
if (!found) { if (!found) {
return errors::Internal( return errors::Internal(
"Internal state of the gradient tape is invalid."); "Internal state of the gradient tape is invalid: "
"none of operations outputs match expected tensor");
} }
} else { } else {
// No record of the target tensor found on the tape, so no gradient // No record of the target tensor found on the tape, so no gradient
@ -415,17 +444,19 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
std::unordered_set<int64> sources_set(source_tensor_ids.begin(), std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end()); source_tensor_ids.end());
BackpropInitialState<BackwardFunction> state = PrepareBackprop( BackpropInitialState<BackwardFunction> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set); target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack = std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor); InitialStack(state.op_tape, state.op_missing_tensor);
std::unordered_map<int64, std::vector<Gradient*>> gradients; std::unordered_map<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
tensor_tape_, state.op_tape, tensor_tape_, state.op_tape,
state.tensor_usage_counts, &gradients); state.tensor_usage_counts, &gradients);
auto cleanup = [&state]() { auto cleanup = [this, &state]() {
// Release all backprop functions if (!persistent_) {
for (const auto& pair : state.op_tape) { // Release all backprop functions
pair.second.backward_function_deleter(); for (const auto& pair : state.op_tape) {
pair.second.backward_function_deleter();
}
} }
}; };
if (!s.ok()) { if (!s.ok()) {
@ -484,6 +515,9 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
std::vector<Gradient*> in_gradients; std::vector<Gradient*> in_gradients;
Status s = vspace.CallBackwardFunction(trace.backward_function, Status s = vspace.CallBackwardFunction(trace.backward_function,
out_gradients, &in_gradients); out_gradients, &in_gradients);
if (!persistent_) {
vspace.ReleaseBackwardFunction(trace.backward_function);
}
if (!s.ok()) { if (!s.ok()) {
cleanup(); cleanup();
return s; return s;
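The persistent-tape machinery added above is what the Python `GradientTape(persistent=...)` argument (changed at the end of this diff) relies on. A minimal usage sketch of the behavior it enables; this assumes the TF 1.4-era `tensorflow.contrib.eager` import path with eager execution enabled, and is illustrative rather than part of this change:

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

x = tf.constant(3.0)

# Non-persistent (default): backward functions are released during the first
# gradient() call, so the tape can only be used once.
with tfe.GradientTape() as g:
  g.watch(x)
  y = x * x
assert g.gradient(y, [x])[0].numpy() == 6.0

# Persistent: backward functions (and the tensors they keep alive) are held
# until the tape object itself is deleted, so gradient() may be called again.
with tfe.GradientTape(persistent=True) as gp:
  gp.watch(x)
  y = x * x
  z = y * y
assert gp.gradient(z, [x])[0].numpy() == 108.0  # 4 * x**3 at x = 3
assert gp.gradient(y, [x])[0].numpy() == 6.0
del gp  # drop the reference to release the held resources
```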


@ -283,7 +283,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleSend(HloInstruction* send) override { Status HandleSend(HloInstruction* send) override {
TF_RET_CHECK(send->users().size() == 1); TF_RET_CHECK(send->users().size() == 1);
const HloInstruction* send_done = send->users()[0]; const HloInstruction* send_done = send->users().front();
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
return CheckShape( return CheckShape(
@ -301,7 +301,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleRecv(HloInstruction* recv) override { Status HandleRecv(HloInstruction* recv) override {
TF_RET_CHECK(recv->users().size() == 1); TF_RET_CHECK(recv->users().size() == 1);
const HloInstruction* recv_done = recv->users()[0]; const HloInstruction* recv_done = recv->users().front();
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
return CheckShape(recv, return CheckShape(recv,


@ -94,6 +94,28 @@ PlatformUtil::GetSupportedPlatforms() {
platforms_string.c_str()); platforms_string.c_str());
} }
/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
const string& platform_name) {
using tensorflow::str_util::Lowercase;
string platform_str = Lowercase(platform_name);
// "cpu" and "host" mean the same thing.
if (platform_str == "cpu") {
platform_str = "host";
}
// "gpu" and "cuda" mean the same thing.
if (platform_str == "gpu") {
platform_str = "cuda";
}
TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
for (se::Platform* platform : platforms) {
if (Lowercase(platform->Name()) == platform_str) {
return platform;
}
}
return InvalidArgument("platform %s not found", platform_name.c_str());
}
// Returns whether the device underlying the given StreamExecutor is supported // Returns whether the device underlying the given StreamExecutor is supported
// by XLA. // by XLA.
static bool IsDeviceSupported(se::StreamExecutor* executor) { static bool IsDeviceSupported(se::StreamExecutor* executor) {
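As a side note, the alias handling in `GetPlatform` above ("cpu" maps to "host", "gpu" maps to "cuda", then a case-insensitive match over the registered platforms) can be summarized with a small standalone Python sketch; the function and variable names here are illustrative only, not part of the change:

```python
def normalize_platform_name(platform_name):
  """Maps user-facing aliases onto the registered platform names."""
  name = platform_name.lower()
  if name == "cpu":   # "cpu" and "host" mean the same thing.
    return "host"
  if name == "gpu":   # "gpu" and "cuda" mean the same thing.
    return "cuda"
  return name


def get_platform(platform_name, supported_platforms):
  target = normalize_platform_name(platform_name)
  for platform in supported_platforms:
    if platform.lower() == target:
      return platform
  raise ValueError("platform %s not found" % platform_name)


assert get_platform("CPU", ["Host", "CUDA"]) == "Host"
assert get_platform("gpu", ["Host", "CUDA"]) == "CUDA"
```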


@ -16,11 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
#include <string>
#include <vector> #include <vector>
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
namespace xla { namespace xla {
@ -39,6 +42,11 @@ class PlatformUtil {
// default platform. Otherwise returns an error. // default platform. Otherwise returns an error.
static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform(); static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
// Returns the platform with the given name. Returns an error if there is
// no such platform.
static StatusOr<perftools::gputools::Platform*> GetPlatform(
const string& platform_name);
// Returns a vector of StreamExecutors for the given platform. The vector is // Returns a vector of StreamExecutors for the given platform. The vector is
// indexed by device ordinal (device numbering used by StreamExecutor). If an // indexed by device ordinal (device numbering used by StreamExecutor). If an
// element is nullptr, then the device is present but not supported by XLA. // element is nullptr, then the device is present but not supported by XLA.


@ -342,7 +342,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// //
// Careful: HloInstruction::operand_index returns the first index the // Careful: HloInstruction::operand_index returns the first index the
// operand appears in, but it may appear more than once! // operand appears in, but it may appear more than once!
if (user->user_count() == 1 && user->users()[0] == while_body_root && if (user->user_count() == 1 && user->users().front() == while_body_root &&
while_body_root->operand_index(user) == user->tuple_index() && while_body_root->operand_index(user) == user->tuple_index() &&
std::count(while_body_root->operands().begin(), std::count(while_body_root->operands().begin(),
while_body_root->operands().end(), user) == 1) { while_body_root->operands().end(), user) == 1) {
@ -444,7 +444,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// This is a GTE of an index that we've removed. Remove it from the // This is a GTE of an index that we've removed. Remove it from the
// cloned computation. // cloned computation.
CHECK(user->user_count() == 0 || CHECK(user->user_count() == 0 ||
user->user_count() == 1 && user->users()[0] == while_body_root) user->user_count() == 1 &&
user->users().front() == while_body_root)
<< "Instruction " << user->ToStringNoMetadata() << "Instruction " << user->ToStringNoMetadata()
<< " should be unused (except by root of while body), but has " << " should be unused (except by root of while body), but has "
"users: {" "users: {"


@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import numpy as np import numpy as np
from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.boosted_trees.python.utils import losses
@ -60,35 +58,27 @@ class LossesTest(test_util.TensorFlowTestCase):
neg_loss = loss_for_negatives.eval() neg_loss = loss_for_negatives.eval()
# For positive labels, points <= 0.3 get max loss of e. # For positive labels, points <= 0.3 get max loss of e.
# For negative labels, these points have minimum loss of 1/e. # For negative labels, these points have minimum loss of 1/e.
for i in range(2): self.assertAllClose(np.exp(np.ones([2, 1])), pos_loss[:2], atol=1e-4)
self.assertAlmostEqual(math.exp(1), pos_loss[i], places=4) self.assertAllClose(np.exp(-np.ones([2, 1])), neg_loss[:2], atol=1e-4)
self.assertAlmostEqual(math.exp(-1), neg_loss[i], places=4)
# For positive labels, points with predictions 0.7 and larger get minimum # loss value of 1/e. For negative labels, these points are wrongly
# loss value of 1/e. For negative labels, these points are wrongly # loss value of 1/e. For negative labels, these points are wrongly
# classified and get loss e. # classified and get loss e.
for i in range(6, 10): self.assertAllClose(np.exp(-np.ones([4, 1])), pos_loss[6:10], atol=1e-4)
self.assertAlmostEqual(math.exp(-1), pos_loss[i], places=4) self.assertAllClose(np.exp(np.ones([4, 1])), neg_loss[6:10], atol=1e-4)
self.assertAlmostEqual(math.exp(1), neg_loss[i], places=4)
# Points in between 0.5-eps, 0.5+eps get loss exp(-label_m*y), where # y = 1/eps *x -1/(2eps), where x is the probability and label_m is either
# y = 1/eps *x -1/(2eps), where x is the probability and label_m is either # y = 1/eps *x -1/(2eps), where x is the probability and label_m is either
# 1 or -1 (for label of 0). # 1 or -1 (for label of 0).
for i in range(2, 6): self.assertAllClose(
self.assertAlmostEqual( np.exp(-(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps)),
math.exp(-1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), pos_loss[2:6], atol=1e-4)
pos_loss[i], self.assertAllClose(
places=4) np.exp(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps),
self.assertAlmostEqual( neg_loss[2:6], atol=1e-4)
math.exp(1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)),
neg_loss[i],
places=4)
def test_per_example_squared_loss(self): def test_per_example_squared_loss(self):
def _squared_loss(p, y):
return np.mean(1.0 * (p - y) * (p - y))
labels = np.array([[0.123], [224.2], [-3], [2], [.3]], dtype=np.float32) labels = np.array([[0.123], [224.2], [-3], [2], [.3]], dtype=np.float32)
weights = array_ops.ones([5, 1], dtypes.float32) weights = array_ops.ones([5, 1], dtypes.float32)
predictions = np.array( predictions = np.array(
@ -99,9 +89,8 @@ class LossesTest(test_util.TensorFlowTestCase):
predictions) predictions)
loss = loss_tensor.eval() loss = loss_tensor.eval()
for i in range(5): self.assertAllClose(
self.assertAlmostEqual( np.square(labels[:5] - predictions[:5]), loss[:5], atol=1e-4)
_squared_loss(labels[i], predictions[i]), loss[i], places=4)
if __name__ == "__main__": if __name__ == "__main__":
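The refactor above swaps per-element `assertAlmostEqual` loops for single vectorized `assertAllClose` calls. A standalone NumPy sketch of the same pattern (the values here are placeholders, not the losses under test):

```python
import numpy as np

# Stand-ins for the values produced by the op under test.
loss = np.full([4, 1], np.exp(-1.0), dtype=np.float32)
expected = np.exp(-np.ones([4, 1]))

# Element-by-element style that the diff removes:
for i in range(4):
  np.testing.assert_almost_equal(loss[i], expected[i], decimal=4)

# Vectorized style that the diff introduces: one assertion over the whole
# array, with clearer failure output and no Python-level loop.
np.testing.assert_allclose(loss, expected, atol=1e-4)
```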


@ -80,13 +80,9 @@ class TPUClusterResolver(ClusterResolver):
raise ImportError('googleapiclient must be installed before using the ' raise ImportError('googleapiclient must be installed before using the '
'TPU cluster resolver') 'TPU cluster resolver')
# TODO(b/67375680): Remove custom URL once TPU APIs are finalized
self._service = discovery.build( self._service = discovery.build(
'tpu', 'tpu', 'v1alpha1',
'v1', credentials=self._credentials)
credentials=self._credentials,
discoveryServiceUrl='https://storage.googleapis.com'
'/tpu-api-definition/v1alpha1.json')
else: else:
self._service = service self._service = service


@ -305,6 +305,18 @@ py_test(
], ],
) )
py_test(
name = "prefetch_dataset_op_test",
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
":dataset_serialization_test",
"//tensorflow/python:platform",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_test( py_test(
name = "range_dataset_op_test", name = "range_dataset_op_test",
size = "small", size = "small",
@ -333,7 +345,7 @@ py_test(
py_test( py_test(
name = "reader_dataset_ops_test", name = "reader_dataset_ops_test",
size = "small", size = "medium",
srcs = ["reader_dataset_ops_test.py"], srcs = ["reader_dataset_ops_test.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
tags = ["no_pip"], tags = ["no_pip"],


@ -0,0 +1,39 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
class PrefetchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def build_dataset(self, seed):
return dataset_ops.Dataset.range(100).prefetch(10).shuffle(
buffer_size=10, seed=seed, reshuffle_each_iteration=False)
def testCore(self):
num_outputs = 100
self.run_core_tests(lambda: self.build_dataset(10),
lambda: self.build_dataset(20), num_outputs)
if __name__ == "__main__":
test.main()


@ -151,16 +151,24 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
return self._fisher_est.damping return self._fisher_est.damping
def minimize(self, *args, **kwargs): def minimize(self, *args, **kwargs):
kwargs["var_list"] = kwargs.get("var_list") or self.variables
if "var_list" not in kwargs:
kwargs["var_list"] = tf_variables.trainable_variables()
if set(kwargs["var_list"]) != set(self.variables): if set(kwargs["var_list"]) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating " raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.") "variables.")
return super(KfacOptimizer, self).minimize(*args, **kwargs) return super(KfacOptimizer, self).minimize(*args, **kwargs)
def compute_gradients(self, *args, **kwargs):
# args[1] could be our var_list
if len(args) > 1:
var_list = args[1]
else:
kwargs["var_list"] = kwargs.get("var_list") or self.variables
var_list = kwargs["var_list"]
if set(var_list) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")
return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
def apply_gradients(self, grads_and_vars, *args, **kwargs): def apply_gradients(self, grads_and_vars, *args, **kwargs):
"""Applies gradients to variables. """Applies gradients to variables.


@ -312,6 +312,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":graph_optimizer", ":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices", "//tensorflow/core/grappler:devices",
@ -320,6 +321,7 @@ cc_library(
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/costs:virtual_placer",
"//tensorflow/core/grappler/utils:frame", "//tensorflow/core/grappler/utils:frame",
], ],
) )


@ -27,7 +27,9 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h" #include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
@ -109,11 +111,13 @@ bool IsMaxPoolGradV1(const NodeDef& node) {
class GraphProcessor { class GraphProcessor {
public: public:
GraphProcessor(GraphDef* graph, NodeMap* node_map, GraphProcessor(const VirtualPlacer& virtual_placer,
const std::unordered_set<string>& nodes_to_preserve) const std::unordered_set<string>& nodes_to_preserve,
: graph_(graph), GraphDef* graph, NodeMap* node_map)
node_map_(node_map), : virtual_placer_(virtual_placer),
nodes_to_preserve_(nodes_to_preserve) {} nodes_to_preserve_(nodes_to_preserve),
graph_(graph),
node_map_(node_map) {}
protected: protected:
NodeDef* AddNodePermConst(const string& name, const string& device, NodeDef* AddNodePermConst(const string& name, const string& device,
@ -122,7 +126,6 @@ class GraphProcessor {
node_map_->AddNode(name, node); node_map_->AddNode(name, node);
node->set_name(name); node->set_name(name);
node->set_op("Const"); node->set_op("Const");
node->set_device(device);
AttrValue attr_data_type; AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32); attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type}); node->mutable_attr()->insert({"dtype", attr_data_type});
@ -133,6 +136,13 @@ class GraphProcessor {
} }
tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
node->mutable_attr()->insert({"value", attr_tensor}); node->mutable_attr()->insert({"value", attr_tensor});
string device_name;
if (device.empty()) {
device_name = virtual_placer_.get_canonical_device_name(*node);
} else {
device_name = device;
}
node->set_device(device_name);
return node; return node;
} }
@ -142,7 +152,6 @@ class GraphProcessor {
node_map_->AddNode(name, node); node_map_->AddNode(name, node);
node->set_name(name); node->set_name(name);
node->set_op("Const"); node->set_op("Const");
node->set_device(device);
AttrValue attr_data_type; AttrValue attr_data_type;
attr_data_type.set_type(dtype); attr_data_type.set_type(dtype);
node->mutable_attr()->insert({"dtype", attr_data_type}); node->mutable_attr()->insert({"dtype", attr_data_type});
@ -151,6 +160,13 @@ class GraphProcessor {
tensor.scalar<int>()() = value; tensor.scalar<int>()() = value;
tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
node->mutable_attr()->insert({"value", attr_tensor}); node->mutable_attr()->insert({"value", attr_tensor});
string device_name;
if (device.empty()) {
device_name = virtual_placer_.get_canonical_device_name(*node);
} else {
device_name = device;
}
node->set_device(device_name);
return node; return node;
} }
@ -159,7 +175,6 @@ class GraphProcessor {
node_map_->AddNode(name, node); node_map_->AddNode(name, node);
node->set_name(name); node->set_name(name);
node->set_op("Const"); node->set_op("Const");
node->set_device(device);
AttrValue attr_data_type; AttrValue attr_data_type;
attr_data_type.set_type(DT_INT32); attr_data_type.set_type(DT_INT32);
node->mutable_attr()->insert({"dtype", attr_data_type}); node->mutable_attr()->insert({"dtype", attr_data_type});
@ -172,26 +187,37 @@ class GraphProcessor {
} }
tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
node->mutable_attr()->insert({"value", attr_tensor}); node->mutable_attr()->insert({"value", attr_tensor});
string device_name;
if (device.empty()) {
device_name = virtual_placer_.get_canonical_device_name(*node);
} else {
device_name = device;
}
node->set_device(device_name);
return node; return node;
} }
const VirtualPlacer& virtual_placer_;
const std::unordered_set<string>& nodes_to_preserve_;
GraphDef* graph_; GraphDef* graph_;
NodeMap* node_map_; NodeMap* node_map_;
const std::unordered_set<string>& nodes_to_preserve_;
}; };
struct OptimizeContext { struct OptimizeContext {
OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map, OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
const VirtualPlacer& virtual_placer,
const std::unordered_set<string>& nodes_to_preserve, const std::unordered_set<string>& nodes_to_preserve,
bool is_in_frame) bool is_in_frame)
: graph(graph), : graph(graph),
node(node), node(node),
node_map(node_map), node_map(node_map),
virtual_placer(virtual_placer),
nodes_to_preserve(nodes_to_preserve), nodes_to_preserve(nodes_to_preserve),
is_in_frame(is_in_frame) {} is_in_frame(is_in_frame) {}
GraphDef* graph; GraphDef* graph;
NodeDef* node; NodeDef* node;
NodeMap* node_map; NodeMap* node_map;
const VirtualPlacer& virtual_placer;
const std::unordered_set<string>& nodes_to_preserve; const std::unordered_set<string>& nodes_to_preserve;
bool is_in_frame; bool is_in_frame;
}; };
@ -199,8 +225,8 @@ struct OptimizeContext {
class NodeProcessor : public GraphProcessor { class NodeProcessor : public GraphProcessor {
public: public:
explicit NodeProcessor(const OptimizeContext& opt_cxt) explicit NodeProcessor(const OptimizeContext& opt_cxt)
: GraphProcessor(opt_cxt.graph, opt_cxt.node_map, : GraphProcessor(opt_cxt.virtual_placer, opt_cxt.nodes_to_preserve,
opt_cxt.nodes_to_preserve), opt_cxt.graph, opt_cxt.node_map),
node_(opt_cxt.node), node_(opt_cxt.node),
is_in_frame_(opt_cxt.is_in_frame) {} is_in_frame_(opt_cxt.is_in_frame) {}
virtual ~NodeProcessor() {} virtual ~NodeProcessor() {}
@ -257,7 +283,25 @@ class NodeProcessor : public GraphProcessor {
} }
virtual bool ShouldProcess() const { virtual bool ShouldProcess() const {
return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs(); return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs() &&
IsOnGPU();
}
virtual bool IsOnGPU() const {
string device_name;
if (node_->device().empty()) {
device_name = virtual_placer_.get_canonical_device_name(*node_);
} else {
device_name = node_->device();
}
string device;
string not_used;
if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
(StringPiece(str_util::Lowercase(device)))
.contains(str_util::Lowercase(DEVICE_GPU))) {
return true;
}
return false;
} }
void UpdateAttrDataFormat() { void UpdateAttrDataFormat() {
@ -536,6 +580,9 @@ class BiasAddGradProcessor : public NodeProcessor {
if (MustPreserve()) { if (MustPreserve()) {
return false; return false;
} }
if (!IsOnGPU()) {
return false;
}
auto input = node_map_->GetNode(node_->input(0)); auto input = node_map_->GetNode(node_->input(0));
if (input) { if (input) {
if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) { if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) {
@ -556,7 +603,7 @@ class Conv2DProcessor : public NodeProcessor {
protected: protected:
bool ShouldProcess() const override { bool ShouldProcess() const override {
return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs() && return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs() &&
(!IsGemmUsed() || no_gemm_); (!IsGemmUsed() || no_gemm_) && IsOnGPU();
} }
TensorShapeProto GetShape(const string& input_name) const { TensorShapeProto GetShape(const string& input_name) const {
@ -667,10 +714,24 @@ class FusedBatchNormGradProcessor : public NodeProcessor {
: NodeProcessor(opt_cxt) {} : NodeProcessor(opt_cxt) {}
protected: protected:
bool ShouldProcess() const override {
return NodeProcessor::ShouldProcess() && IsTraining();
}
std::vector<int> GetInputPos() const override { std::vector<int> GetInputPos() const override {
std::vector<int> input_pos = {0, 1}; std::vector<int> input_pos = {0, 1};
return input_pos; return input_pos;
} }
private:
bool IsTraining() const {
if (node_->attr().find("is_training") != node_->attr().end()) {
if (node_->attr().at("is_training").b()) {
return true;
}
}
return false;
}
}; };
class MaxPoolGradProcessor : public NodeProcessor { class MaxPoolGradProcessor : public NodeProcessor {
@ -693,7 +754,7 @@ class AgnosticNodeProcessor : public NodeProcessor {
protected: protected:
bool ShouldProcess() const override { bool ShouldProcess() const override {
return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
IsNodeAfterNCHWToNHWC(); IsNodeAfterNCHWToNHWC() && IsOnGPU();
} }
bool IsNodeAfterNCHWToNHWC() const { bool IsNodeAfterNCHWToNHWC() const {
@ -746,7 +807,8 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
IsNodeAfterNCHWToNHWC() && IsNodeAfterNCHWToNHWC() &&
(Is4DOperateWithND(4) || Is4DOperateWithScalar() || (Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
Is4DOperateWithVector()); Is4DOperateWithVector()) &&
IsOnGPU();
} }
std::vector<int> GetInputPos() const override { std::vector<int> GetInputPos() const override {
@ -855,7 +917,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
protected: protected:
bool ShouldProcess() const override { bool ShouldProcess() const override {
return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
IsNodeAfterNCHWToNHWC() && IsAlongDimC(); IsNodeAfterNCHWToNHWC() && IsAlongDimC() && IsOnGPU();
} }
std::vector<int> GetInputPos() const override { std::vector<int> GetInputPos() const override {
@ -920,7 +982,7 @@ class PadProcessor : public AgnosticNodeProcessor {
protected: protected:
bool ShouldProcess() const override { bool ShouldProcess() const override {
return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() && return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
IsNodeAfterNCHWToNHWC() && PaddingSupported(); IsNodeAfterNCHWToNHWC() && PaddingSupported() && IsOnGPU();
} }
Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); } Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
@ -1132,7 +1194,8 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
protected: protected:
bool ShouldProcess() const override { bool ShouldProcess() const override {
return !MustPreserve() && IsDimsN(*node_, 2) && HasOutputs() && return !MustPreserve() && IsDimsN(*node_, 2) && HasOutputs() &&
IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW(); IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() &&
IsOnGPU();
} }
Status AddLayoutTransposeToOutputs() override { return Status::OK(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
@ -1183,7 +1246,7 @@ class SumProcessor : public AgnosticNodeProcessor {
auto input0 = node_map_->GetNode(node_->input(0)); auto input0 = node_map_->GetNode(node_->input(0));
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
(IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) && (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
IsAlongDimNHW(); IsAlongDimNHW() && IsOnGPU();
} }
Status AddLayoutTransposeToOutputs() override { return Status::OK(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
@ -1243,42 +1306,41 @@ class SumProcessor : public AgnosticNodeProcessor {
class DataLayoutOptimizer : GraphProcessor { class DataLayoutOptimizer : GraphProcessor {
public: public:
explicit DataLayoutOptimizer( explicit DataLayoutOptimizer(
LayoutOptimizer::TuningConfig config, const VirtualPlacer& virtual_placer,
const std::unordered_set<string>& nodes_to_preserve, const LayoutOptimizer::TuningConfig& config,
const string& default_device, GraphDef* graph, NodeMap* node_map) const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
: GraphProcessor(graph, node_map, nodes_to_preserve), NodeMap* node_map)
config_(config), : GraphProcessor(virtual_placer, nodes_to_preserve, graph, node_map),
default_device_(default_device) {} config_(config) {}
Status Optimize() { Status Optimize() {
LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size(); VLOG(1) << "Number of nodes for original graph: " << graph_->node_size();
TF_RETURN_IF_ERROR(Expand()); TF_RETURN_IF_ERROR(Expand());
LOG(INFO) << "Number of nodes after Expand: " << graph_->node_size(); VLOG(1) << "Number of nodes after Expand: " << graph_->node_size();
TF_RETURN_IF_ERROR(Collapse()); TF_RETURN_IF_ERROR(Collapse());
LOG(INFO) << "Number of nodes after Collapse: " << graph_->node_size(); VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size();
return Status::OK(); return Status::OK();
} }
private: private:
NodeDef* AddNodePermNHWCToNCHW() { NodeDef* AddNodePermNHWCToNCHW() {
return AddNodePermConst(kPermNHWCToNCHW, default_device_, {0, 3, 1, 2}); return AddNodePermConst(kPermNHWCToNCHW, "", {0, 3, 1, 2});
} }
NodeDef* AddNodePermNCHWToNHWC() { NodeDef* AddNodePermNCHWToNHWC() {
return AddNodePermConst(kPermNCHWToNHWC, default_device_, {0, 2, 3, 1}); return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1});
} }
NodeDef* AddNodeConcatConst() { NodeDef* AddNodeConcatConst() {
return AddNodeConstScalar(kConcatConst, default_device_, DT_INT32, 1); return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1);
} }
NodeDef* AddNodeGatherAxisConst() { NodeDef* AddNodeGatherAxisConst() {
return AddNodeConstScalar(kGatherAxisConst, default_device_, DT_INT32, 0); return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0);
} }
NodeDef* AddNodeReductionConst() { NodeDef* AddNodeReductionConst() {
return GraphProcessor::AddNodeReductionConst(kReductionConst, return GraphProcessor::AddNodeReductionConst(kReductionConst, "");
default_device_);
} }
// Expand all nodes that are in NHWC but support NCHW or are layout agnostic. // Expand all nodes that are in NHWC but support NCHW or are layout agnostic.
@ -1295,8 +1357,8 @@ class DataLayoutOptimizer : GraphProcessor {
ops_format_supported.end()) { ops_format_supported.end()) {
auto node = graph_->mutable_node(i); auto node = graph_->mutable_node(i);
bool is_in_frame = !frames[node].empty(); bool is_in_frame = !frames[node].empty();
OptimizeContext opt_cxt(graph_, node, node_map_, nodes_to_preserve_, OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
is_in_frame); nodes_to_preserve_, is_in_frame);
std::unique_ptr<NodeProcessor> node_processor; std::unique_ptr<NodeProcessor> node_processor;
if (IsAvgPoolGrad(*node)) { if (IsAvgPoolGrad(*node)) {
node_processor.reset(new AvgPoolGradProcessor(opt_cxt)); node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
@ -1343,8 +1405,8 @@ class DataLayoutOptimizer : GraphProcessor {
ops_format_agnostic.end()) { ops_format_agnostic.end()) {
auto node = graph_->mutable_node(i); auto node = graph_->mutable_node(i);
bool is_in_frame = !frames[node].empty(); bool is_in_frame = !frames[node].empty();
OptimizeContext opt_cxt(graph_, node, node_map_, nodes_to_preserve_, OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
is_in_frame); nodes_to_preserve_, is_in_frame);
std::unique_ptr<NodeProcessor> node_processor; std::unique_ptr<NodeProcessor> node_processor;
if (IsAddN(*node)) { if (IsAddN(*node)) {
node_processor.reset(new AddNProcessor(opt_cxt)); node_processor.reset(new AddNProcessor(opt_cxt));
@ -1419,8 +1481,7 @@ class DataLayoutOptimizer : GraphProcessor {
return Status::OK(); return Status::OK();
} }
LayoutOptimizer::TuningConfig config_; const LayoutOptimizer::TuningConfig& config_;
string default_device_;
}; };
int GetNumTranspose(const GraphDef& graph) { int GetNumTranspose(const GraphDef& graph) {
@ -1430,7 +1491,7 @@ int GetNumTranspose(const GraphDef& graph) {
number++; number++;
} }
} }
LOG(INFO) << "Number of Transpose nodes: " << number; VLOG(1) << "Number of Transpose nodes: " << number;
return number; return number;
} }
@ -1455,7 +1516,6 @@ int GetNumGPUs(const Cluster& cluster) {
Status LayoutOptimizer::Tune(const GrapplerItem& item, Status LayoutOptimizer::Tune(const GrapplerItem& item,
const GraphProperties& graph_properties, const GraphProperties& graph_properties,
const string& default_device,
const TuningConfig& config, GraphDef* output) { const TuningConfig& config, GraphDef* output) {
auto status = graph_properties.AnnotateOutputShapes(output); auto status = graph_properties.AnnotateOutputShapes(output);
if (!status.ok()) { if (!status.ok()) {
@ -1463,8 +1523,8 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item,
return status; return status;
} }
NodeMap node_map(output); NodeMap node_map(output);
DataLayoutOptimizer layout_optimizer(config, nodes_to_preserve_, DataLayoutOptimizer layout_optimizer(*virtual_placer_, config,
default_device, output, &node_map); nodes_to_preserve_, output, &node_map);
status = layout_optimizer.Optimize(); status = layout_optimizer.Optimize();
return status; return status;
} }
@ -1477,6 +1537,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
return Status::OK(); return Status::OK();
} }
virtual_placer_.reset(new VirtualPlacer(cluster));
nodes_to_preserve_ = item.NodesToPreserve(); nodes_to_preserve_ = item.NodesToPreserve();
GraphProperties graph_properties(item); GraphProperties graph_properties(item);
auto status = graph_properties.InferStatically(); auto status = graph_properties.InferStatically();
@ -1487,20 +1548,13 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TuningConfig config; TuningConfig config;
config.no_gemm = false; config.no_gemm = false;
string default_device = "/job:localhost/replica:0/task:0/cpu:0"; status = Tune(item, graph_properties, config, output);
if (cluster) {
if (!cluster->GetDevices().empty()) {
default_device = cluster->GetDevices().begin()->first;
}
}
status = Tune(item, graph_properties, default_device, config, output);
// This is based on an empirical observation that if the introduced Transpose // This is based on an empirical observation that if the introduced Transpose
// nodes is more than 30, not using GEMM implementation would result in better // nodes is more than 30, not using GEMM implementation would result in better
// performance. // performance.
if (status.ok() && GetNumTranspose(*output) > 30) { if (status.ok() && GetNumTranspose(*output) > 30) {
config.no_gemm = true; config.no_gemm = true;
status = Tune(item, graph_properties, default_device, config, output); status = Tune(item, graph_properties, config, output);
} }
if (!status.ok()) { if (!status.ok()) {
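The new `IsOnGPU()` gate above resolves a node's device (falling back to the VirtualPlacer's canonical name when the device field is empty), splits the device name, and does a case-insensitive check for a GPU device type. A rough standalone Python sketch of just the string test, using device names like those in the new tests; it approximates `DeviceNameUtils::SplitDeviceName` rather than reproducing it:

```python
def is_on_gpu(device_name):
  """Returns True if the device-type component of `device_name` is a GPU.

  Mirrors the C++ check: take the trailing "<type>:<id>" component and do a
  case-insensitive substring match against "gpu".
  """
  last = device_name.rstrip("/").split("/")[-1]  # e.g. "device:gpu:0" or "CPU:0"
  return "gpu" in last.lower()


assert is_on_gpu("/job:w/replica:0/task:0/device:gpu:0")
assert not is_on_gpu("/job:w/replica:0/task:0/device:cpu:0")
assert not is_on_gpu("/CPU:0")
assert is_on_gpu("/device:GPU:1")
```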


@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_ #define TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
namespace tensorflow { namespace tensorflow {
@ -47,10 +48,10 @@ class LayoutOptimizer : public GraphOptimizer {
const GraphDef& optimize_output, double result) override; const GraphDef& optimize_output, double result) override;
private: private:
std::unique_ptr<VirtualPlacer> virtual_placer_;
std::unordered_set<string> nodes_to_preserve_; std::unordered_set<string> nodes_to_preserve_;
Status Tune(const GrapplerItem& item, const GraphProperties& graph_properties, Status Tune(const GrapplerItem& item, const GraphProperties& graph_properties,
const string& default_device, const TuningConfig& config, const TuningConfig& config, GraphDef* output);
GraphDef* output);
}; };
} // end namespace grappler } // end namespace grappler


@ -39,6 +39,11 @@ class LayoutOptimizerTest : public ::testing::Test {
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
const string& padding) { const string& padding) {
return SimpleConv2D(s, input_size, filter_size, padding, "");
}
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
const string& padding, const string& device) {
int batch_size = 128; int batch_size = 128;
int input_height = input_size; int input_height = input_size;
int input_width = input_size; int input_width = input_size;
@ -59,8 +64,8 @@ class LayoutOptimizerTest : public ::testing::Test {
Output filter = Output filter =
ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
Output conv = ops::Conv2D(s->WithOpName("Conv2D"), input, filter, Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
{1, stride, stride, 1}, padding); filter, {1, stride, stride, 1}, padding);
return conv; return conv;
} }
@ -109,6 +114,36 @@ class LayoutOptimizerTest : public ::testing::Test {
return tensor; return tensor;
} }
Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
int batch_size = 16;
int input_height = 8;
int input_width = 8;
int input_channels = 3;
TensorShape shape({batch_size, input_height, input_width, input_channels});
Tensor data(DT_FLOAT, shape);
test::FillIota<float>(&data, 1.0f);
Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
Output y_backprop =
ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));
TensorShape shape_vector({input_channels});
Tensor data_vector(DT_FLOAT, shape_vector);
test::FillIota<float>(&data_vector, 2.0f);
Output scale =
ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
Output reserve1 =
ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
Output reserve2 =
ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));
ops::FusedBatchNormGrad::Attrs attrs;
attrs.is_training_ = is_training;
auto output =
ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
x, scale, reserve1, reserve2, attrs);
return output.x_backprop;
}
std::unique_ptr<VirtualCluster> virtual_cluster_; std::unique_ptr<VirtualCluster> virtual_cluster_;
}; };
@ -278,6 +313,92 @@ TEST_F(LayoutOptimizerTest, PreserveFetch) {
EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC"); EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
} }
TEST_F(LayoutOptimizerTest, EmptyDevice) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 3, 2, "VALID");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto conv_node = node_map.GetNode("Conv2D");
EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}
TEST_F(LayoutOptimizerTest, GPUDevice) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv =
SimpleConv2D(&s, 3, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto conv_node = node_map.GetNode("Conv2D");
EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}
TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv =
SimpleConv2D(&s, 3, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto conv_node = node_map.GetNode("Conv2D");
EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}
TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 3, 2, "VALID", "/CPU:0");
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto conv_node = node_map.GetNode("Conv2D");
EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}
TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto conv_node = node_map.GetNode("FusedBatchNormGrad");
EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}
TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
LayoutOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
NodeMap node_map(&output);
auto conv_node = node_map.GetNode("FusedBatchNormGrad");
EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow


@ -6059,6 +6059,7 @@ tf_kernel_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
], ],
) )


@ -79,16 +79,20 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx, Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override { bool* end_of_sequence) override {
if (!input_impl_) { {
*end_of_sequence = true; tf_shared_lock l(mu_);
return Status::OK(); if (!input_impl_) {
} *end_of_sequence = true;
Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); return Status::OK();
while (!s.ok()) { }
out_tensors->clear(); Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); while (!s.ok()) {
out_tensors->clear();
s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
} }
if (*end_of_sequence) { if (*end_of_sequence) {
mutex_lock l(mu_);
input_impl_.reset(); input_impl_.reset();
} }
return Status::OK(); return Status::OK();
@ -96,6 +100,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
protected: protected:
Status SaveInternal(IteratorStateWriter* writer) override { Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (input_impl_) if (input_impl_)
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
else else
@ -106,6 +111,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(OpKernelContext* ctx, Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (reader->Contains(full_name("input_impls_empty"))) if (reader->Contains(full_name("input_impls_empty")))
input_impl_.reset(); input_impl_.reset();
else else
@ -114,7 +120,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
} }
private: private:
std::unique_ptr<IteratorBase> input_impl_; mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
}; };
const DatasetBase* const input_; const DatasetBase* const input_;


@ -14,9 +14,10 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <deque> #include <deque>
#include "tensorflow/core/kernels/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/dataset.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
namespace tensorflow { namespace tensorflow {
@ -39,14 +40,14 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES(ctx, buffer_size > 0, OP_REQUIRES(ctx, buffer_size > 0,
errors::InvalidArgument("buffer_size must be > 0")); errors::InvalidArgument("buffer_size must be > 0"));
*output = new Dataset(input, buffer_size); *output = new Dataset(ctx, input, buffer_size);
} }
private: private:
class Dataset : public DatasetBase { class Dataset : public GraphDatasetBase {
public: public:
Dataset(const DatasetBase* input, int64 buffer_size) Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size)
: input_(input), buffer_size_(buffer_size) { : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
input_->Ref(); input_->Ref();
} }
@ -67,6 +68,18 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "PrefetchDatasetOp::Dataset"; } string DebugString() override { return "PrefetchDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
Node* buffer_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph_node, buffer_size}, output));
return Status::OK();
}
private: private:
class Iterator : public DatasetIterator<Dataset> { class Iterator : public DatasetIterator<Dataset> {
public: public:
@ -121,7 +134,10 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
// Wake the prefetch thread, in case it has been waiting // Wake the prefetch thread, in case it has been waiting
// for space in the buffer. // for space in the buffer.
cond_var_.notify_one(); // Also wake up threads from other calls to GetNext.
// TODO(mrry): Consider using different condition variables
// for GetNext and Prefetch.
cond_var_.notify_all();
return s; return s;
} else if (prefetch_thread_finished_) { } else if (prefetch_thread_finished_) {
*end_of_sequence = true; *end_of_sequence = true;
@ -130,6 +146,69 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
} }
} }
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
// Acquire both locks to ensure that the prefetch thread and
// all GetNext threads are blocked.
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
for (size_t i = 0; i < buffer_.size(); i++) {
auto& buffer_element = buffer_[i];
TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
if (buffer_element.status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("buffer[", i, "].size")),
buffer_element.value.size()));
for (size_t j = 0; j < buffer_element.value.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
strings::StrCat("buffer[", i, "][", j, "]"),
buffer_element.value[j]));
}
}
}
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
IteratorStateReader* reader) override {
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
buffer_.clear();
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
size_t buffer_size;
{
int64 temp;
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("buffer_size"), &temp));
buffer_size = static_cast<size_t>(temp);
}
for (size_t i = 0; i < buffer_size; i++) {
buffer_.emplace_back();
auto& buffer_element = buffer_.back();
TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
if (buffer_element.status.ok()) {
size_t value_size;
{
int64 temp;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("buffer[", i, "].size")), &temp));
value_size = static_cast<size_t>(temp);
}
buffer_element.value.reserve(value_size);
for (size_t j = 0; j < value_size; j++) {
buffer_element.value.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
strings::StrCat("buffer[", i, "][", j, "]"),
&buffer_element.value.back()));
}
}
}
return Status::OK();
}
private: private:
// A buffer element comprises a status and (if that status is // A buffer element comprises a status and (if that status is
// OK) a vector of tensors, representing an element of the input dataset. // OK) a vector of tensors, representing an element of the input dataset.
@ -173,6 +252,12 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
} }
// 2. Read the next element. // 2. Read the next element.
// Acquire the parent lock since we will be reading an element
// from the input iterator. Note that we do not wish to release
// this lock until we have added the fetched element to the
// `buffer_`, or else there will be local state that may be missed
// by SaveInternal.
mutex_lock parent_l(parent_mu_);
bool end_of_sequence; bool end_of_sequence;
BufferElement buffer_element; BufferElement buffer_element;
buffer_element.status = input_impl_->GetNext( buffer_element.status = input_impl_->GetNext(
@ -193,8 +278,50 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
} }
} }
Status WriteStatus(IteratorStateWriter* writer, size_t index,
const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
status.error_message()));
}
return Status::OK();
}
Status ReadStatus(IteratorStateReader* reader, size_t index,
Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
if (code != error::Code::OK) {
string error_message;
TF_RETURN_IF_ERROR(
reader->ReadScalar(ErrorMessageKey(index), &error_message));
*status = Status(code, error_message);
} else {
*status = Status::OK();
}
return Status::OK();
}
string CodeKey(size_t index) {
return full_name(strings::StrCat("status[", index, "].code"));
}
string ErrorMessageKey(size_t index) {
return full_name(strings::StrCat("status[", index, "].error_message"));
}
// This mutex is used to ensure exclusivity between multiple threads
// reading/writing this iterator's local state.
mutex mu_; mutex mu_;
const std::unique_ptr<IteratorBase> input_impl_; // This mutex is used to ensure exclusivity between multiple threads
// accessing the parent iterator. We keep this separate from `mu_` to
// allow prefetching to run in parallel with GetNext calls.
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_; condition_variable cond_var_;
std::deque<BufferElement> buffer_ GUARDED_BY(mu_); std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
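WriteStatus/ReadStatus above round-trip a Status as an error code plus an optional message under per-element keys. A hedged stand-alone sketch of the same scheme follows; plain (code, message) pairs stand in for tensorflow::Status, a dict stands in for the state reader/writer, and the full_name prefix is omitted.

```python
# Hedged sketch of the WriteStatus/ReadStatus round trip.
OK = 0

def write_status(state, index, code, message):
    state["status[%d].code" % index] = code
    if code != OK:  # the message is only stored for non-OK statuses
        state["status[%d].error_message" % index] = message

def read_status(state, index):
    code = state["status[%d].code" % index]
    return code, state.get("status[%d].error_message" % index, "")

state = {}
write_status(state, 0, OK, "")
write_status(state, 1, 8, "resource exhausted")   # 8: an example non-OK code
assert read_status(state, 1) == (8, "resource exhausted")
```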

View File

@ -4352,7 +4352,10 @@ py_library(
], ],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":pywrap_tensorflow_internal"], deps = [
":pywrap_tensorflow_internal",
":tf_cluster",
],
) )
py_test( py_test(

View File

@ -798,13 +798,41 @@ class GradientTape(object):
grad = g.gradient(y, [x])[0] grad = g.gradient(y, [x])[0]
assert grad.numpy() == 6.0 assert grad.numpy() == 6.0
``` ```
By default, the resources held by a GradientTape are released as soon as
the GradientTape.gradient() method is called. However, if one needs to compute
multiple gradients over the same computation, a persistent GradientTape can be
created instead. Persistent tapes allow multiple calls to the gradient() method
and release their resources when the tape object is garbage collected.
Example usage:
```python
with tfe.GradientTape(persistent=True) as g:
x = tf.constant(3.0)
g.watch(x)
y = x * x
z = y * y
dz_dx = g.gradient(z, [x])[0]
assert dz_dx.numpy() == 108.0 # 4*x^3 at x = 3
dy_dx = g.gradient(y, [x])[0]
assert dy_dx.numpy() == 6.0
del g # Drop the reference to the tape
""" """
def __init__(self): def __init__(self, persistent=False):
"""Creates a new GradientTape.
Args:
persistent: Boolean controlling whether a persistent gradient tape
is created. Must be True or False.
"""
self._tape = None self._tape = None
self._persistent = persistent
def __enter__(self): def __enter__(self):
tape.push_new_tape() tape.push_new_tape(persistent=self._persistent)
return self return self
def __exit__(self, typ, value, traceback): def __exit__(self, typ, value, traceback):
@ -838,12 +866,14 @@ class GradientTape(object):
than once. than once.
""" """
if self._tape is None: if self._tape is None:
raise RuntimeError("GradientTape.gradient can only be called once, and " raise RuntimeError("GradientTape.gradient can only be called once "
"on non-persistent tapes, and "
"only when the context manager has exited.") "only when the context manager has exited.")
sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable) sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
else x else x
for x in sources] for x in sources]
grad = imperative_grad.imperative_grad( grad = imperative_grad.imperative_grad(
_default_vspace, self._tape, [target], sources) _default_vspace, self._tape, [target], sources)
self._tape = None if not self._persistent:
self._tape = None
return grad return grad
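For contrast with the persistent example in the docstring above, a non-persistent tape still frees its state after the first gradient() call, and a second call raises with the updated message. A small sketch, using the same tfe/tf aliases as the docstring example:

```python
with tfe.GradientTape() as g:   # persistent defaults to False
  x = tf.constant(3.0)
  g.watch(x)
  y = x * x
assert g.gradient(y, [x])[0].numpy() == 6.0
g.gradient(y, [x])  # raises RuntimeError: GradientTape.gradient can only be
                    # called once on non-persistent tapes, and only when the
                    # context manager has exited.
```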

View File

@ -314,6 +314,37 @@ class BackpropTest(test.TestCase):
RuntimeError, 'GradientTape.gradient can only be called once'): RuntimeError, 'GradientTape.gradient can only be called once'):
g.gradient(y, [x]) g.gradient(y, [x])
def testPersistentTape(self):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant(3.0)
g.watch(x)
y = x * x
z = y * y
dz_dx = g.gradient(z, [x])[0]
self.assertEqual(dz_dx.numpy(), 4*3*3*3)
dy_dx = g.gradient(y, [x])[0]
self.assertEqual(dy_dx.numpy(), 2*3)
del g
def testPersistentNestedTape(self):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant(3.0)
g.watch(x)
y = x * x
with backprop.GradientTape(persistent=True) as gg:
gg.watch(y)
z = 2 * y
for _ in range(2):
inner_grad = gg.gradient(z, [y])[0]
self.assertEqual(inner_grad.numpy(), 2.0)
y += inner_grad
del gg
grad = g.gradient(y, [x])[0]
self.assertEqual(grad.numpy(), 6.0)
grad = g.gradient(z, [x])[0]
self.assertEqual(grad.numpy(), 12.0)
del g
def testGradientTapeVariable(self): def testGradientTapeVariable(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v') v = resource_variable_ops.ResourceVariable(1.0, name='v')
with backprop.GradientTape() as g: with backprop.GradientTape() as g:

View File

@ -88,7 +88,8 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
// Pushes a new tape into the thread-local stack. // Pushes a new tape into the thread-local stack.
void TFE_Py_TapeStackPushNew(); // `persistent` must be a PyBool_Type, i.e. either Py_True or Py_False
void TFE_Py_TapeStackPushNew(PyObject* persistent);
// Pops the tape from the top of the stack and returns it. // Pops the tape from the top of the stack and returns it.
PyObject* TFE_Py_TapeStackPop(); PyObject* TFE_Py_TapeStackPop();

View File

@ -469,7 +469,8 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
class GradientTape class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyObject> { : public tensorflow::eager::GradientTape<PyObject, PyObject> {
public: public:
GradientTape() {} explicit GradientTape(bool persistent)
: tensorflow::eager::GradientTape<PyObject, PyObject>(persistent) {}
void WatchVariable(PyObject* v) { void WatchVariable(PyObject* v) {
watched_variables_.insert(v); watched_variables_.insert(v);
@ -557,11 +558,11 @@ std::vector<TFE_Py_Tape*>* GetTapeStack() {
} }
#endif #endif
void TFE_Py_TapeStackPushNew() { void TFE_Py_TapeStackPushNew(PyObject* persistent) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew; TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return; if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
tape->tape = new GradientTape(); tape->tape = new GradientTape(persistent == Py_True);
GetTapeStack()->push_back(tape); GetTapeStack()->push_back(tape);
} }
@ -704,6 +705,7 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i); PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
list.push_back(FastTensorId(tensor)); list.push_back(FastTensorId(tensor));
if (PyErr_Occurred()) { if (PyErr_Occurred()) {
Py_DECREF(seq);
return list; return list;
} }
} }
@ -889,7 +891,6 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
PyObject* py_result = PyEval_CallObject( PyObject* py_result = PyEval_CallObject(
reinterpret_cast<PyObject*>(backward_function), grads); reinterpret_cast<PyObject*>(backward_function), grads);
Py_DECREF(grads); Py_DECREF(grads);
Py_DECREF(backward_function);
if (py_result == nullptr) { if (py_result == nullptr) {
return tensorflow::errors::Internal("gradient function threw exceptions"); return tensorflow::errors::Internal("gradient function threw exceptions");
} }
@ -917,6 +918,10 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
void ReleaseBackwardFunction(PyObject* backward_function) const final {
Py_DECREF(backward_function);
}
void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
private: private:

View File

@ -33,9 +33,9 @@ class Tape(object):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
def push_new_tape(): def push_new_tape(persistent=False):
"""Pushes a new tape onto the tape stack.""" """Pushes a new tape onto the tape stack."""
pywrap_tensorflow.TFE_Py_TapeStackPushNew() pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent)
def watch(tensor): def watch(tensor):
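push_new_tape is a thin shim: the Python bool is handed to TFE_Py_TapeStackPushNew unchanged, which expects Py_True or Py_False. A hedged sketch of the raw round trip is below; it assumes TFE_Py_TapeStackPop is exposed through pywrap_tensorflow the same way, and ordinary code should go through tfe.GradientTape instead of these bindings.

```python
from tensorflow.python import pywrap_tensorflow

# Push a persistent tape, then immediately pop it back off the thread-local
# stack; this mirrors what push_new_tape() and GradientTape.__enter__ do.
pywrap_tensorflow.TFE_Py_TapeStackPushNew(True)
t = pywrap_tensorflow.TFE_Py_TapeStackPop()
```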

View File

@ -14,6 +14,14 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
%include "tensorflow/python/platform/base.i" %include "tensorflow/python/platform/base.i"
%include <std_shared_ptr.i>
%include "item.i"
// Wrap the cluster in an object that SWIG can manipulate. This ensures the
// cluster's destructor is called upon garbage collection instead of leaking memory.
struct GCluster {
std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
};
%{ %{
#include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/protobuf/device_properties.pb.h"
@ -72,6 +80,7 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
} }
%{ %{
#include <memory>
#include <vector> #include <vector>
#include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/clusters/single_machine.h"
@ -82,39 +91,56 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/protobuf/device_properties.pb.h"
static tensorflow::grappler::Cluster* TF_NewCluster( // Provide the implementation of the GCluster struct here.
bool allow_soft_placement, struct GCluster {
bool disable_detailed_stats, TF_Status* out_status) { GCluster() {}
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); GCluster(tensorflow::grappler::Cluster* cluster) : cluster_(cluster) {}
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();;
tensorflow::grappler::Cluster* operator->() const {
return cluster_.get();
}
tensorflow::grappler::Cluster* get() const {
return cluster_.get();
}
bool is_none() const {
return cluster_.get() == nullptr;
}
std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
};
static GCluster TF_NewCluster(bool allow_soft_placement,
bool disable_detailed_stats, TF_Status* out_status) {
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
int timeout_s = 60 * 10; int timeout_s = 60 * 10;
tensorflow::grappler::Cluster* cluster = tensorflow::grappler::Cluster* cluster_ =
new tensorflow::grappler::SingleMachine( new tensorflow::grappler::SingleMachine(
timeout_s, num_cpu_cores, num_gpus); timeout_s, num_cpu_cores, num_gpus);
cluster->DisableDetailedStats(disable_detailed_stats); cluster_->DisableDetailedStats(disable_detailed_stats);
cluster->AllowSoftPlacement(allow_soft_placement); cluster_->AllowSoftPlacement(allow_soft_placement);
tensorflow::Status status = cluster->Provision(); tensorflow::Status status = cluster_->Provision();
tensorflow::Set_TF_Status_from_Status(out_status, status); tensorflow::Set_TF_Status_from_Status(out_status, status);
return cluster; return GCluster(cluster_);
} }
static tensorflow::grappler::Cluster* TF_NewVirtualCluster( static GCluster TF_NewVirtualCluster(
const std::vector<tensorflow::NamedDevice>& named_devices, const std::vector<tensorflow::NamedDevice>& named_devices,
TF_Status* out_status) { TF_Status* out_status) {
std::unordered_map<string, tensorflow::DeviceProperties> devices; std::unordered_map<string, tensorflow::DeviceProperties> devices;
for (const auto& named_device : named_devices) { for (const auto& named_device : named_devices) {
devices[named_device.name()]= named_device.properties(); devices[named_device.name()]= named_device.properties();
} }
tensorflow::grappler::Cluster* cluster = tensorflow::grappler::Cluster* cluster_ =
new tensorflow::grappler::VirtualCluster(devices); new tensorflow::grappler::VirtualCluster(devices);
tensorflow::Status status = cluster->Provision(); tensorflow::Status status = cluster_->Provision();
tensorflow::Set_TF_Status_from_Status(out_status, status); tensorflow::Set_TF_Status_from_Status(out_status, status);
return cluster; return GCluster(cluster_);
} }
static void TF_DeleteCluster(tensorflow::grappler::Cluster* cluster) { static void TF_ShutdownCluster(GCluster cluster) {
cluster->Shutdown(); cluster->Shutdown();
delete cluster;
} }
tensorflow::Status _GetOpPerformanceDataAndRunTime( tensorflow::Status _GetOpPerformanceDataAndRunTime(
@ -136,8 +162,9 @@ tensorflow::Status _GetOpPerformanceDataAndRunTime(
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) { static PyObject* TF_ListDevices(GCluster cluster) {
const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices(); const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyList_New(devices.size()); PyObject* result = PyList_New(devices.size());
int i = 0; int i = 0;
for (auto& dev : devices) { for (auto& dev : devices) {
@ -150,17 +177,18 @@ static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) {
PyList_SetItem(result, i, dev_obj); PyList_SetItem(result, i, dev_obj);
++i; ++i;
} }
PyGILState_Release(gstate);
return result; return result;
} }
static PyObject* TF_MeasureCosts( static PyObject* TF_MeasureCosts(
const tensorflow::grappler::GrapplerItem* item, GItem item,
tensorflow::grappler::Cluster* cluster, GCluster cluster,
bool generate_timeline, TF_Status* out_status) { bool generate_timeline, TF_Status* out_status) {
tensorflow::OpPerformanceList op_performance_data; tensorflow::OpPerformanceList op_performance_data;
tensorflow::StepStats step_stats; tensorflow::StepStats step_stats;
tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster, 10, 0); tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), 10, 0);
tensorflow::grappler::Costs costs; tensorflow::grappler::Costs costs;
tensorflow::Status status = _GetOpPerformanceDataAndRunTime( tensorflow::Status status = _GetOpPerformanceDataAndRunTime(
@ -184,6 +212,7 @@ static PyObject* TF_MeasureCosts(
if (!status.ok()) { if (!status.ok()) {
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* op_perf_objs = PyList_New( PyObject* op_perf_objs = PyList_New(
op_performance_data.op_performance_size()); op_performance_data.op_performance_size());
for (int i = 0; i < op_performance_data.op_performance_size(); i++) { for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
@ -211,17 +240,19 @@ static PyObject* TF_MeasureCosts(
status = tensorflow::Status(tensorflow::error::Code::INTERNAL, status = tensorflow::Status(tensorflow::error::Code::INTERNAL,
"Error setting return tuples."); "Error setting return tuples.");
tensorflow::Set_TF_Status_from_Status(out_status, status); tensorflow::Set_TF_Status_from_Status(out_status, status);
Py_RETURN_NONE; Py_INCREF(Py_None);
ret = Py_None;
} }
PyGILState_Release(gstate);
return ret; return ret;
} }
static PyObject* TF_DeterminePeakMemoryUsage( static PyObject* TF_DeterminePeakMemoryUsage(
const tensorflow::grappler::GrapplerItem* item, GItem item,
tensorflow::grappler::Cluster* cluster, GCluster cluster,
TF_Status* out_status) { TF_Status* out_status) {
if (!item || !cluster) { if (item.is_none() || cluster.is_none()) {
tensorflow::Status status(tensorflow::error::Code::INTERNAL, tensorflow::Status status(tensorflow::error::Code::INTERNAL,
"You need both a cluster and an item to determine peak memory usage"); "You need both a cluster and an item to determine peak memory usage");
tensorflow::Set_TF_Status_from_Status(out_status, status); tensorflow::Set_TF_Status_from_Status(out_status, status);
@ -231,7 +262,7 @@ static PyObject* TF_DeterminePeakMemoryUsage(
tensorflow::Status status; tensorflow::Status status;
if (cluster->DetailedStatsEnabled()) { if (cluster->DetailedStatsEnabled()) {
status = memory.InferDynamically(cluster); status = memory.InferDynamically(cluster.get());
} else { } else {
status = memory.InferStatically(cluster->GetDevices()); status = memory.InferStatically(cluster->GetDevices());
} }
@ -240,6 +271,7 @@ static PyObject* TF_DeterminePeakMemoryUsage(
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* result = PyDict_New(); PyObject* result = PyDict_New();
for (const auto& device : cluster->GetDevices()) { for (const auto& device : cluster->GetDevices()) {
const tensorflow::grappler::GraphMemory::MemoryUsage& usage = const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
@ -261,24 +293,24 @@ static PyObject* TF_DeterminePeakMemoryUsage(
PyTuple_SetItem(ret, 1, per_device); PyTuple_SetItem(ret, 1, per_device);
PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret); PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret);
} }
PyGILState_Release(gstate);
return result; return result;
} }
%} %}
// Wrap these functions. // Wrap these functions.
static GCluster TF_NewCluster(
static tensorflow::grappler::Cluster* TF_NewCluster(
bool allow_soft_placement, bool disable_detailed_stats, TF_Status* out_status); bool allow_soft_placement, bool disable_detailed_stats, TF_Status* out_status);
static tensorflow::grappler::Cluster* TF_NewVirtualCluster( static GCluster TF_NewVirtualCluster(
const std::vector<tensorflow::NamedDevice>& named_devices, const std::vector<tensorflow::NamedDevice>& named_devices,
TF_Status* out_status); TF_Status* out_status);
static void TF_DeleteCluster(tensorflow::grappler::Cluster* cluster); static void TF_ShutdownCluster(GCluster cluster);
static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster); static PyObject* TF_ListDevices(GCluster cluster);
static PyObject* TF_MeasureCosts( static PyObject* TF_MeasureCosts(
const tensorflow::grappler::GrapplerItem* item, tensorflow::grappler::Cluster* cluster, GItem item, GCluster cluster,
bool generate_timeline, TF_Status* out_status); bool generate_timeline, TF_Status* out_status);
static PyObject* TF_DeterminePeakMemoryUsage( static PyObject* TF_DeterminePeakMemoryUsage(
const tensorflow::grappler::GrapplerItem* item, tensorflow::grappler::Cluster* cluster, GItem item, GCluster cluster,
TF_Status* out_status); TF_Status* out_status);

View File

@ -46,6 +46,7 @@ class Cluster(object):
the local machine. the local machine.
""" """
self._tf_cluster = None self._tf_cluster = None
self._generate_timeline = not disable_timeline
with errors.raise_exception_on_not_ok_status() as status: with errors.raise_exception_on_not_ok_status() as status:
if devices is None: if devices is None:
self._tf_cluster = tf_cluster.TF_NewCluster( self._tf_cluster = tf_cluster.TF_NewCluster(
@ -54,11 +55,10 @@ class Cluster(object):
devices_serialized = [device.SerializeToString() for device in devices] devices_serialized = [device.SerializeToString() for device in devices]
self._tf_cluster = tf_cluster.TF_NewVirtualCluster( self._tf_cluster = tf_cluster.TF_NewVirtualCluster(
devices_serialized, status) devices_serialized, status)
self._generate_timeline = not disable_timeline
def __del__(self): def __del__(self):
if self._tf_cluster is not None: if self._tf_cluster is not None:
tf_cluster.TF_DeleteCluster(self._tf_cluster) tf_cluster.TF_ShutdownCluster(self._tf_cluster)
@property @property
def tf_cluster(self): def tf_cluster(self):
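With the GCluster wrapper holding a shared_ptr, the Python-side destructor now only shuts the cluster down; the underlying C++ object is freed when the last reference dies. A minimal lifecycle sketch, assuming the constructor defaults shown above are acceptable:

```python
from tensorflow.python.grappler import cluster as gcluster

c = gcluster.Cluster()   # provisions a local SingleMachine cluster (defaults assumed)
# ... use c.tf_cluster with TF_MeasureCosts, TF_OptimizeGraph, etc. ...
del c                    # __del__ calls tf_cluster.TF_ShutdownCluster(self._tf_cluster)
```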

View File

@ -15,6 +15,7 @@ limitations under the License.
%include "tensorflow/python/lib/core/strings.i" %include "tensorflow/python/lib/core/strings.i"
%include "tensorflow/python/platform/base.i" %include "tensorflow/python/platform/base.i"
%include "cluster.i"
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string; char* c_string;
@ -42,8 +43,8 @@ limitations under the License.
%} %}
%{ %{
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
per_node_report, tensorflow::grappler::Cluster* cluster) { GCluster cluster) {
tensorflow::grappler::ItemConfig cfg; tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false; cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item = std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
@ -53,7 +54,7 @@ per_node_report, tensorflow::grappler::Cluster* cluster) {
} }
string suffix; string suffix;
tensorflow::grappler::CostAnalyzer analyzer(*item, cluster, suffix); tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix);
std::stringstream os; std::stringstream os;
analyzer.GenerateReport(os, per_node_report); analyzer.GenerateReport(os, per_node_report);
@ -62,5 +63,5 @@ per_node_report, tensorflow::grappler::Cluster* cluster) {
%} %}
string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
per_node_report, tensorflow::grappler::Cluster* cluster); GCluster cluster);

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
%include <std_shared_ptr.i>
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string; char* c_string;
Py_ssize_t py_size; Py_ssize_t py_size;
@ -30,7 +31,12 @@ limitations under the License.
$1 = &temp; $1 = &temp;
} }
%newobject TF_NewItem; // Wrap the item in an object that SWIG can manipulate. This ensures the
// item's destructor is called upon garbage collection instead of leaking memory.
struct GItem {
std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
};
%{ %{
#include <unordered_set> #include <unordered_set>
@ -42,8 +48,26 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
static tensorflow::grappler::GrapplerItem* TF_NewItem( // Provide the implementation of the GItem struct here.
struct GItem {
GItem() {}
GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {}
tensorflow::grappler::GrapplerItem* operator->() const {
return item_.get();
}
const tensorflow::grappler::GrapplerItem& operator*() const {
return *item_.get();
}
bool is_none() const {
return item_.get() == nullptr;
}
std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
};
static GItem TF_NewItem(
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation, const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
bool ignore_user_placement, TF_Status* out_status) { bool ignore_user_placement, TF_Status* out_status) {
if (meta_graph.collection_def().count("train_op") == 0) { if (meta_graph.collection_def().count("train_op") == 0) {
@ -65,11 +89,11 @@ static tensorflow::grappler::GrapplerItem* TF_NewItem(
return nullptr; return nullptr;
} }
tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK()); tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
return item.release(); return GItem(item.release());
} }
static std::vector<string> TF_IdentifyImportantOps(const tensorflow::grappler::GrapplerItem* item) { static std::vector<string> TF_IdentifyImportantOps(GItem item) {
if (!item) { if (item.is_none()) {
return {}; return {};
} }
@ -91,8 +115,8 @@ static std::vector<string> TF_IdentifyImportantOps(const tensorflow::grappler::G
return ops; return ops;
} }
static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* item) { static PyObject* TF_GetOpProperties(GItem item) {
if (!item) { if (item.is_none()) {
Py_RETURN_NONE; Py_RETURN_NONE;
} }
tensorflow::grappler::GraphProperties properties(*item); tensorflow::grappler::GraphProperties properties(*item);
@ -101,6 +125,7 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it
Py_RETURN_NONE; Py_RETURN_NONE;
} }
PyGILState_STATE gstate = PyGILState_Ensure();
PyObject* props = PyDict_New(); PyObject* props = PyDict_New();
for (const auto& node : item->graph.node()) { for (const auto& node : item->graph.node()) {
const string& node_name = node.name(); const string& node_name = node.name();
@ -115,8 +140,8 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it
PyList_SetItem(prop, i, output_prop); PyList_SetItem(prop, i, output_prop);
} }
CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop)); CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop));
} }
PyGILState_Release(gstate);
return props; return props;
} }
@ -124,8 +149,8 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it
// Wrap these functions. // Wrap these functions.
static tensorflow::grappler::GrapplerItem* TF_NewItem( static GItem TF_NewItem(
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation, const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
bool ignore_user_placement, TF_Status* out_status); bool ignore_user_placement, TF_Status* out_status);
static std::vector<string> TF_IdentifyImportantOps(const tensorflow::grappler::GrapplerItem* item); static std::vector<string> TF_IdentifyImportantOps(GItem item);
static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* item); static PyObject* TF_GetOpProperties(GItem item);

View File

@ -15,6 +15,7 @@ limitations under the License.
%include "tensorflow/python/platform/base.i" %include "tensorflow/python/platform/base.i"
%include "cluster.i"
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) { %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
char* c_string; char* c_string;
@ -92,7 +93,7 @@ void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* dev
} }
PyObject* TF_OptimizeGraph( PyObject* TF_OptimizeGraph(
tensorflow::grappler::Cluster* cluster, GCluster cluster,
const tensorflow::RewriterConfig& rewriter_config, const tensorflow::RewriterConfig& rewriter_config,
const tensorflow::MetaGraphDef& metagraph, const tensorflow::MetaGraphDef& metagraph,
bool verbose, const string& graph_id, TF_Status* out_status) { bool verbose, const string& graph_id, TF_Status* out_status) {
@ -102,17 +103,10 @@ PyObject* TF_OptimizeGraph(
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item = std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
std::unique_ptr<tensorflow::grappler::VirtualCluster> virtual_cluster;
if (cluster == nullptr) {
std::unordered_map<string, tensorflow::DeviceProperties> device_map;
DetectDevices(&device_map);
virtual_cluster.reset(new tensorflow::grappler::VirtualCluster(device_map));
cluster = virtual_cluster.get();
}
tensorflow::DeviceBase* cpu_device = nullptr; tensorflow::DeviceBase* cpu_device = nullptr;
tensorflow::GraphDef out_graph; tensorflow::GraphDef out_graph;
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config); tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config);
tensorflow::Status status = optimizer.Optimize(cluster, *grappler_item, &out_graph); tensorflow::Status status = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
if (verbose) { if (verbose) {
optimizer.PrintResult(); optimizer.PrintResult();
} }
@ -127,7 +121,7 @@ PyObject* TF_OptimizeGraph(
// Wrap this function // Wrap this function
PyObject* TF_OptimizeGraph( PyObject* TF_OptimizeGraph(
tensorflow::grappler::Cluster* cluster, GCluster cluster,
const tensorflow::RewriterConfig& rewriter_config, const tensorflow::RewriterConfig& rewriter_config,
const tensorflow::MetaGraphDef& metagraph, bool verbose, const tensorflow::MetaGraphDef& metagraph, bool verbose,
const string& graph_id, TF_Status* out_status); const string& graph_id, TF_Status* out_status);

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.python import pywrap_tensorflow as tf_opt from tensorflow.python import pywrap_tensorflow as tf_opt
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.grappler import cluster as gcluster
def OptimizeGraph(rewriter_config, def OptimizeGraph(rewriter_config,
@ -30,8 +31,9 @@ def OptimizeGraph(rewriter_config,
cluster=None): cluster=None):
"""Optimize the provided metagraph.""" """Optimize the provided metagraph."""
with errors.raise_exception_on_not_ok_status() as status: with errors.raise_exception_on_not_ok_status() as status:
ret_from_swig = tf_opt.TF_OptimizeGraph(None if cluster is None else if cluster is None:
cluster.tf_cluster, cluster = gcluster.Cluster()
ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
rewriter_config.SerializeToString(), rewriter_config.SerializeToString(),
metagraph.SerializeToString(), metagraph.SerializeToString(),
verbose, graph_id, status) verbose, graph_id, status)
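With the fallback above, callers can omit the cluster entirely and OptimizeGraph builds a default local one itself. A hedged usage sketch follows; the metagraph construction is illustrative, and the only hard requirement shown in this commit is that the MetaGraphDef carry a 'train_op' collection.

```python
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.grappler import tf_optimizer

with tf.Graph().as_default():
  a = tf.constant(1.0) + tf.constant(2.0)
  tf.add_to_collection('train_op', a)        # grappler requires a 'train_op' collection
  metagraph = tf.train.export_meta_graph()

rewriter_config = rewriter_config_pb2.RewriterConfig()
# cluster is omitted: a gcluster.Cluster() is now created inside OptimizeGraph.
optimized_graph_def = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
```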

View File

@ -565,33 +565,34 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
if not _like_rnncell(cell): if not _like_rnncell(cell):
raise TypeError("cell must be an instance of RNNCell") raise TypeError("cell must be an instance of RNNCell")
# By default, time_major==False and inputs are batch-major: shaped
# [batch, time, depth]
# For internal calculations, we transpose to [time, batch, depth]
flat_input = nest.flatten(inputs)
if not time_major:
# (B,T,D) => (T,B,D)
flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
parallel_iterations = parallel_iterations or 32
if sequence_length is not None:
sequence_length = math_ops.to_int32(sequence_length)
if sequence_length.get_shape().ndims not in (None, 1):
raise ValueError(
"sequence_length must be a vector of length batch_size, "
"but saw shape: %s" % sequence_length.get_shape())
sequence_length = array_ops.identity( # Just to find it in the graph.
sequence_length, name="sequence_length")
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope: with vs.variable_scope(scope or "rnn") as varscope:
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
if context.in_graph_mode(): if context.in_graph_mode():
if varscope.caching_device is None: if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device) varscope.set_caching_device(lambda op: op.device)
# By default, time_major==False and inputs are batch-major: shaped
# [batch, time, depth]
# For internal calculations, we transpose to [time, batch, depth]
flat_input = nest.flatten(inputs)
if not time_major:
# (B,T,D) => (T,B,D)
flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
parallel_iterations = parallel_iterations or 32
if sequence_length is not None:
sequence_length = math_ops.to_int32(sequence_length)
if sequence_length.get_shape().ndims not in (None, 1):
raise ValueError(
"sequence_length must be a vector of length batch_size, "
"but saw shape: %s" % sequence_length.get_shape())
sequence_length = array_ops.identity( # Just to find it in the graph.
sequence_length, name="sequence_length")
batch_size = _best_effort_input_batch_size(flat_input) batch_size = _best_effort_input_batch_size(flat_input)
if initial_state is not None: if initial_state is not None:
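The relocated block above only canonicalizes the inputs; the batch-major/time-major convention it enforces is a transpose of the first two axes. A tiny stand-alone illustration with NumPy:

```python
import numpy as np

x = np.zeros((8, 20, 32))            # batch-major: [batch, time, depth]
x_time_major = x.transpose(1, 0, 2)  # time-major:  [time, batch, depth], as used internally
assert x_time_major.shape == (20, 8, 32)
```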

View File

@ -108,11 +108,6 @@ static CUdeviceptr AsCudaDevicePtr(DeviceMemoryBase *gpu_mem) {
return AsCudaDevicePtr(*gpu_mem); return AsCudaDevicePtr(*gpu_mem);
} }
static CudaContext* GetCudaContext(Stream *stream) {
return static_cast<CUDAExecutor *>(stream->parent()->implementation())
->cuda_context();
}
CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec) { CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec) {
CHECK(cuda_exec != nullptr); CHECK(cuda_exec != nullptr);
return cuda_exec->cuda_context(); return cuda_exec->cuda_context();
@ -380,11 +375,11 @@ bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
void **kernel_params = const_cast<void **>(args.argument_addresses().data()); void **kernel_params = const_cast<void **>(args.argument_addresses().data());
if (!CUDADriver::LaunchKernel(GetCudaContext(stream), cufunc, block_dims.x, if (!CUDADriver::LaunchKernel(context_, cufunc, block_dims.x, block_dims.y,
block_dims.y, block_dims.z, thread_dims.x, block_dims.z, thread_dims.x, thread_dims.y,
thread_dims.y, thread_dims.z, thread_dims.z, args.number_of_shared_bytes(),
args.number_of_shared_bytes(), custream, custream, kernel_params,
kernel_params, nullptr /* = extra */)) { nullptr /* = extra */)) {
LOG(ERROR) << "failed to launch CUDA kernel with args: " LOG(ERROR) << "failed to launch CUDA kernel with args: "
<< args.number_of_arguments() << args.number_of_arguments()
<< "; thread dim: " << thread_dims.ToString() << "; thread dim: " << thread_dims.ToString()