Merge pull request #14922 from sb2nov/branch_176732156
Branch 176732156
commit 55055a23c8
@@ -106,6 +106,12 @@ class VSpace {
 
   // Deletes the input tensor.
   virtual void DeleteGradient(Gradient* gradient) const = 0;
+
+  // Lets this VSpace know that it can release resources held by the
+  // `backward_function`. It will not be called again.
+  // `backward_function` must not be null.
+  virtual void ReleaseBackwardFunction(
+      BackwardFunction* backward_function) const = 0;
 };
 
 // Traces the execution of operations, doing eager garbage collection, and
@@ -113,7 +119,11 @@ class VSpace {
 template <typename Gradient, typename BackwardFunction>
 class GradientTape {
  public:
-  GradientTape() {}
+  // If `persistent` is true, GradientTape will not eagerly delete backward
+  // functions (and hence the tensors they keep alive). Instead, everything
+  // is deleted in ~GradientTape. Persistent GradientTapes are useful when
+  // users want to compute multiple gradients over the same tape.
+  GradientTape(bool persistent) : persistent_(persistent) {}
   ~GradientTape() {
     for (const auto& pair : op_tape_) {
       pair.second.backward_function_deleter();
@@ -150,6 +160,10 @@ class GradientTape {
   // Map from tensor id to number of remaining usages (i.e. how many entries in
   // the tape refer to it); to aid in tape garbage collection.
   std::unordered_map<int64, int64> tensor_usage_;
+
+  // If false, all activations are deleted in the first call to ComputeGradient.
+  // Else, only when this is destructed.
+  bool persistent_;
 };
 
 // Template instantiations here
@@ -279,11 +293,16 @@ struct BackpropInitialState {
   std::unordered_map<int64, int64> op_missing_tensor;
 };
 
+// If `persistent_tape` is true, op_tape is not changed and none of the
+// backwards functions are deleted.
+// If `persistent_tape` is false, op_tape is cleared and backwards functions
+// not needed for gradient computation are deleted. Backwards functions that
+// are needed are copied and returned in BackpropInitialState.
 template <typename BackwardFunction>
 BackpropInitialState<BackwardFunction> PrepareBackprop(
     gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
-    OpTape<BackwardFunction> op_tape,
-    const std::unordered_set<int64>& sources_set) {
+    OpTape<BackwardFunction>* op_tape,
+    const std::unordered_set<int64>& sources_set, bool persistent_tape) {
   std::vector<int64> tensor_stack;
   tensor_stack.reserve(target.size());
   for (auto t : target) {
@@ -298,9 +317,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
       continue;
    }
    int64 op_id = op_id_it->second;
-    auto op_it = op_tape.find(op_id);
+    auto op_it = op_tape->find(op_id);
    auto result_op_it = result.op_tape.find(op_id);
-    if (op_id == -1 || op_it == op_tape.end() ||
+    if (op_id == -1 || op_it == op_tape->end() ||
        result_op_it != result.op_tape.end()) {
      continue;
    }
@@ -317,7 +336,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
        }
      }
    }
-    op_tape.erase(op_it);
+    if (!persistent_tape) {
+      op_tape->erase(op_it);
+    }
  }
  for (auto& pair : result.tensor_usage_counts) {
    auto it = tensor_tape.find(pair.first);
@@ -325,9 +346,15 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
      result.op_missing_tensor[it->second] += 1;
    }
  }
-  // Call destructors for all unneeded gradient functions.
-  for (const auto& op_pair : op_tape) {
-    op_pair.second.backward_function_deleter();
+  if (!persistent_tape) {
+    // Call destructors for all unneeded gradient functions and
+    // clear the op_tape. We can clear the tape because ownership of
+    // backward functions that will be used for gradient computation
+    // has been transferred to `result`.
+    for (const auto& op_pair : *op_tape) {
+      op_pair.second.backward_function_deleter();
+    }
+    op_tape->clear();
  }
  return result;
 }
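The ownership comment above is the crux of the non-persistent path: backward functions needed for this backprop pass are copied into `result`, everything else is destroyed, and the tape is emptied. A minimal Python sketch of that hand-off (illustrative only; plain dicts stand in for the OpTape):

    def prepare_backprop(op_tape, needed_op_ids, persistent_tape):
        # Copy the entries the backward pass will need into the result.
        result = {op_id: fn for op_id, fn in op_tape.items()
                  if op_id in needed_op_ids}
        if not persistent_tape:
            # Ownership of the needed backward functions now lives in
            # `result`, so the tape itself can be emptied; a persistent
            # tape keeps everything for later ComputeGradient calls.
            op_tape.clear()
        return result

    tape = {1: "backward_fn_1", 2: "backward_fn_2"}
    state = prepare_backprop(tape, {1}, persistent_tape=False)
    assert state == {1: "backward_fn_1"} and tape == {}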
@@ -369,7 +396,8 @@ Status InitialGradients(
        auto op_it = op_tape.find(tensor_it->second);
        if (op_it == op_tape.end()) {
          return errors::Internal(
-              "Internal state of the gradient tape is invalid.");
+              "Internal state of the gradient tape is invalid: "
+              "failed to find operation producing a tensor");
        }
        bool found = false;
        for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
@@ -383,7 +411,8 @@ Status InitialGradients(
        }
        if (!found) {
          return errors::Internal(
-              "Internal state of the gradient tape is invalid.");
+              "Internal state of the gradient tape is invalid: "
+              "none of the operation's outputs match the expected tensor");
        }
      } else {
        // No record of the target tensor found on the tape, so no gradient
@@ -415,17 +444,19 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
                                        source_tensor_ids.end());
  BackpropInitialState<BackwardFunction> state = PrepareBackprop(
-      target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set);
+      target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
  std::vector<int64> op_stack =
      InitialStack(state.op_tape, state.op_missing_tensor);
  std::unordered_map<int64, std::vector<Gradient*>> gradients;
  Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
                              tensor_tape_, state.op_tape,
                              state.tensor_usage_counts, &gradients);
-  auto cleanup = [&state]() {
-    // Release all backprop functions
-    for (const auto& pair : state.op_tape) {
-      pair.second.backward_function_deleter();
+  auto cleanup = [this, &state]() {
+    if (!persistent_) {
+      // Release all backprop functions
+      for (const auto& pair : state.op_tape) {
+        pair.second.backward_function_deleter();
+      }
    }
  };
  if (!s.ok()) {
@@ -484,6 +515,9 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
    std::vector<Gradient*> in_gradients;
    Status s = vspace.CallBackwardFunction(trace.backward_function,
                                           out_gradients, &in_gradients);
+    if (!persistent_) {
+      vspace.ReleaseBackwardFunction(trace.backward_function);
+    }
    if (!s.ok()) {
      cleanup();
      return s;
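Taken together, these tape.h hunks thread a `persistent` flag from construction through ComputeGradient, so backward functions are released after the first gradient computation only when the tape is not persistent. The user-visible behavior this plumbing supports looks roughly like the sketch below, written with the modern tf.GradientTape spelling; at the time of this commit the eager API lived under contrib, so the exact import path is an assumption:

    import tensorflow as tf

    x = tf.constant(3.0)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x)
        y = x * x
        z = y * y

    # A persistent tape keeps its backward functions alive, so it can be
    # queried repeatedly; a non-persistent tape errors on the second
    # gradient() call because the functions were already released.
    print(tape.gradient(z, x))  # d(x^4)/dx = 4x^3 = 108
    print(tape.gradient(y, x))  # d(x^2)/dx = 2x = 6
    del tape  # release the resources the persistent tape holds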
@@ -283,7 +283,7 @@ class ShapeVerifier : public DfsHloVisitor {
 
   Status HandleSend(HloInstruction* send) override {
     TF_RET_CHECK(send->users().size() == 1);
-    const HloInstruction* send_done = send->users()[0];
+    const HloInstruction* send_done = send->users().front();
     TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
     TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
     return CheckShape(
@@ -301,7 +301,7 @@ class ShapeVerifier : public DfsHloVisitor {
 
   Status HandleRecv(HloInstruction* recv) override {
     TF_RET_CHECK(recv->users().size() == 1);
-    const HloInstruction* recv_done = recv->users()[0];
+    const HloInstruction* recv_done = recv->users().front();
     TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
     TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
     return CheckShape(recv,
@@ -94,6 +94,28 @@ PlatformUtil::GetSupportedPlatforms() {
                          platforms_string.c_str());
 }
 
+/*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform(
+    const string& platform_name) {
+  using tensorflow::str_util::Lowercase;
+  string platform_str = Lowercase(platform_name);
+  // "cpu" and "host" mean the same thing.
+  if (platform_str == "cpu") {
+    platform_str = "host";
+  }
+  // "gpu" and "cuda" mean the same thing.
+  if (platform_str == "gpu") {
+    platform_str = "cuda";
+  }
+
+  TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms());
+  for (se::Platform* platform : platforms) {
+    if (Lowercase(platform->Name()) == platform_str) {
+      return platform;
+    }
+  }
+  return InvalidArgument("platform %s not found", platform_name.c_str());
+}
+
 // Returns whether the device underlying the given StreamExecutor is supported
 // by XLA.
 static bool IsDeviceSupported(se::StreamExecutor* executor) {
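GetPlatform folds the "cpu"/"host" and "gpu"/"cuda" aliases before matching case-insensitively against the supported platforms. The same logic as a small self-contained Python sketch (names here are illustrative, not part of XLA):

    PLATFORM_ALIASES = {"cpu": "host", "gpu": "cuda"}

    def get_platform(platform_name, supported_names):
        wanted = platform_name.lower()
        wanted = PLATFORM_ALIASES.get(wanted, wanted)  # fold aliases first
        for name in supported_names:
            if name.lower() == wanted:                 # case-insensitive match
                return name
        raise ValueError("platform %s not found" % platform_name)

    assert get_platform("CPU", ["Host", "CUDA"]) == "Host"
    assert get_platform("gpu", ["Host", "CUDA"]) == "CUDA"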
@@ -16,11 +16,14 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
 
+#include <string>
 #include <vector>
 
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
 
 namespace xla {
 
@@ -39,6 +42,11 @@ class PlatformUtil {
   // default platform. Otherwise returns an error.
   static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
 
+  // Returns the platform according to the given name. Returns an error if
+  // there is no such platform.
+  static StatusOr<perftools::gputools::Platform*> GetPlatform(
+      const string& platform_name);
+
   // Returns a vector of StreamExecutors for the given platform. The vector is
   // indexed by device ordinal (device numbering used by StreamExecutor). If an
   // element is nullptr, then the device is present but not supported by XLA.
@@ -342,7 +342,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
     //
     // Careful: HloInstruction::operand_index returns the first index the
     // operand appears in, but it may appear more than once!
-    if (user->user_count() == 1 && user->users()[0] == while_body_root &&
+    if (user->user_count() == 1 && user->users().front() == while_body_root &&
         while_body_root->operand_index(user) == user->tuple_index() &&
         std::count(while_body_root->operands().begin(),
                    while_body_root->operands().end(), user) == 1) {
@@ -444,7 +444,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
       // This is a GTE of an index that we've removed. Remove it from the
       // cloned computation.
       CHECK(user->user_count() == 0 ||
-            user->user_count() == 1 && user->users()[0] == while_body_root)
+            user->user_count() == 1 &&
+                user->users().front() == while_body_root)
           << "Instruction " << user->ToStringNoMetadata()
           << " should be unused (except by root of while body), but has "
             "users: {"
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 import numpy as np
 
 from tensorflow.contrib.boosted_trees.python.utils import losses
@@ -60,35 +58,27 @@ class LossesTest(test_util.TensorFlowTestCase):
       neg_loss = loss_for_negatives.eval()
       # For positive labels, points <= 0.3 get max loss of e.
       # For negative labels, these points have minimum loss of 1/e.
-      for i in range(2):
-        self.assertAlmostEqual(math.exp(1), pos_loss[i], places=4)
-        self.assertAlmostEqual(math.exp(-1), neg_loss[i], places=4)
+      self.assertAllClose(np.exp(np.ones([2, 1])), pos_loss[:2], atol=1e-4)
+      self.assertAllClose(np.exp(-np.ones([2, 1])), neg_loss[:2], atol=1e-4)
 
       # For positive labels, points with predictions 0.7 and larger get minimum
       # loss value of 1/e. For negative labels, these points are wrongly
       # classified and get loss e.
-      for i in range(6, 10):
-        self.assertAlmostEqual(math.exp(-1), pos_loss[i], places=4)
-        self.assertAlmostEqual(math.exp(1), neg_loss[i], places=4)
+      self.assertAllClose(np.exp(-np.ones([4, 1])), pos_loss[6:10], atol=1e-4)
+      self.assertAllClose(np.exp(np.ones([4, 1])), neg_loss[6:10], atol=1e-4)
 
       # Points in between 0.5-eps, 0.5+eps get loss exp(-label_m*y), where
       # y = 1/eps * x - 1/(2eps), where x is the probability and label_m is
       # either 1 or -1 (for label of 0).
-      for i in range(2, 6):
-        self.assertAlmostEqual(
-            math.exp(-1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)),
-            pos_loss[i],
-            places=4)
-        self.assertAlmostEqual(
-            math.exp(1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)),
-            neg_loss[i],
-            places=4)
+      self.assertAllClose(
+          np.exp(-(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps)),
+          pos_loss[2:6], atol=1e-4)
+      self.assertAllClose(
+          np.exp(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps),
+          neg_loss[2:6], atol=1e-4)
 
   def test_per_example_squared_loss(self):
 
-    def _squared_loss(p, y):
-      return np.mean(1.0 * (p - y) * (p - y))
-
     labels = np.array([[0.123], [224.2], [-3], [2], [.3]], dtype=np.float32)
     weights = array_ops.ones([5, 1], dtypes.float32)
     predictions = np.array(
@@ -99,9 +89,8 @@ class LossesTest(test_util.TensorFlowTestCase):
         predictions)
 
       loss = loss_tensor.eval()
-      for i in range(5):
-        self.assertAlmostEqual(
-            _squared_loss(labels[i], predictions[i]), loss[i], places=4)
+      self.assertAllClose(
+          np.square(labels[:5] - predictions[:5]), loss[:5], atol=1e-4)
 
 
 if __name__ == "__main__":
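The test rewrite trades per-element assertAlmostEqual loops for single vectorized assertAllClose calls over slices, which is what lets the `math` import go away. The two styles check the same values, as this NumPy sketch shows (np.testing.assert_allclose stands in for the test-case method):

    import numpy as np

    pos_loss = np.exp(np.ones([2, 1]))  # the values the test expects

    # Old style: one scalar comparison per element.
    for i in range(2):
        np.testing.assert_allclose(np.exp(1.0), pos_loss[i], atol=1e-4)

    # New style: the same comparison, vectorized over the slice.
    np.testing.assert_allclose(np.exp(np.ones([2, 1])), pos_loss[:2],
                               atol=1e-4)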
@@ -80,13 +80,9 @@ class TPUClusterResolver(ClusterResolver):
         raise ImportError('googleapiclient must be installed before using the '
                           'TPU cluster resolver')
 
-      # TODO(b/67375680): Remove custom URL once TPU APIs are finalized
       self._service = discovery.build(
-          'tpu',
-          'v1',
-          credentials=self._credentials,
-          discoveryServiceUrl='https://storage.googleapis.com'
-          '/tpu-api-definition/v1alpha1.json')
+          'tpu', 'v1alpha1',
+          credentials=self._credentials)
     else:
       self._service = service
 
@@ -305,6 +305,18 @@ py_test(
     ],
 )
 
+py_test(
+    name = "prefetch_dataset_op_test",
+    size = "small",
+    srcs = ["prefetch_dataset_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":dataset_serialization_test",
+        "//tensorflow/python:platform",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+)
+
 py_test(
     name = "range_dataset_op_test",
     size = "small",
@@ -333,7 +345,7 @@ py_test(
 
 py_test(
     name = "reader_dataset_ops_test",
-    size = "small",
+    size = "medium",
     srcs = ["reader_dataset_ops_test.py"],
     srcs_version = "PY2AND3",
     tags = ["no_pip"],
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class PrefetchDatasetSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def build_dataset(self, seed):
+    return dataset_ops.Dataset.range(100).prefetch(10).shuffle(
+        buffer_size=10, seed=seed, reshuffle_each_iteration=False)
+
+  def testCore(self):
+    num_outputs = 100
+    self.run_core_tests(lambda: self.build_dataset(10),
+                        lambda: self.build_dataset(20), num_outputs)
+
+
+if __name__ == "__main__":
+  test.main()
@@ -151,16 +151,24 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
     return self._fisher_est.damping
 
   def minimize(self, *args, **kwargs):
-    if "var_list" not in kwargs:
-      kwargs["var_list"] = tf_variables.trainable_variables()
+    kwargs["var_list"] = kwargs.get("var_list") or self.variables
 
     if set(kwargs["var_list"]) != set(self.variables):
       raise ValueError("var_list doesn't match with set of Fisher-estimating "
                        "variables.")
 
     return super(KfacOptimizer, self).minimize(*args, **kwargs)
 
+  def compute_gradients(self, *args, **kwargs):
+    # args[1] could be our var_list
+    if len(args) > 1:
+      var_list = args[1]
+    else:
+      kwargs["var_list"] = kwargs.get("var_list") or self.variables
+      var_list = kwargs["var_list"]
+    if set(var_list) != set(self.variables):
+      raise ValueError("var_list doesn't match with set of Fisher-estimating "
+                       "variables.")
+    return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
+
   def apply_gradients(self, grads_and_vars, *args, **kwargs):
     """Applies gradients to variables.
 
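The new compute_gradients override inspects args[1] because the base tf.train.Optimizer signature is compute_gradients(loss, var_list=None, ...), so callers may pass var_list positionally, and a kwargs-only check would skip the Fisher-variable validation. A stripped-down sketch of that dispatch (hypothetical helper, not KFAC code):

    def extract_var_list(args, kwargs, default_vars):
        # compute_gradients(loss, var_list, ...): var_list may be positional.
        if len(args) > 1:
            return args[1]
        kwargs["var_list"] = kwargs.get("var_list") or default_vars
        return kwargs["var_list"]

    fisher_vars = ["w", "b"]
    assert extract_var_list(("loss", ["w"]), {}, fisher_vars) == ["w"]
    assert extract_var_list(("loss",), {}, fisher_vars) == fisher_vars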
@@ -312,6 +312,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":graph_optimizer",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:devices",
@@ -320,6 +321,7 @@ cc_library(
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
         "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/core/grappler/costs:virtual_placer",
         "//tensorflow/core/grappler/utils:frame",
     ],
 )
@@ -27,7 +27,9 @@ limitations under the License.
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/frame.h"
 #include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -109,11 +111,13 @@ bool IsMaxPoolGradV1(const NodeDef& node) {
 
 class GraphProcessor {
  public:
-  GraphProcessor(GraphDef* graph, NodeMap* node_map,
-                 const std::unordered_set<string>& nodes_to_preserve)
-      : graph_(graph),
-        node_map_(node_map),
-        nodes_to_preserve_(nodes_to_preserve) {}
+  GraphProcessor(const VirtualPlacer& virtual_placer,
+                 const std::unordered_set<string>& nodes_to_preserve,
+                 GraphDef* graph, NodeMap* node_map)
+      : virtual_placer_(virtual_placer),
+        nodes_to_preserve_(nodes_to_preserve),
+        graph_(graph),
+        node_map_(node_map) {}
 
  protected:
   NodeDef* AddNodePermConst(const string& name, const string& device,
@@ -122,7 +126,6 @@ class GraphProcessor {
     node_map_->AddNode(name, node);
     node->set_name(name);
     node->set_op("Const");
-    node->set_device(device);
     AttrValue attr_data_type;
     attr_data_type.set_type(DT_INT32);
     node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -133,6 +136,13 @@ class GraphProcessor {
     }
     tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
     node->mutable_attr()->insert({"value", attr_tensor});
+    string device_name;
+    if (device.empty()) {
+      device_name = virtual_placer_.get_canonical_device_name(*node);
+    } else {
+      device_name = device;
+    }
+    node->set_device(device_name);
     return node;
   }
 
@@ -142,7 +152,6 @@ class GraphProcessor {
     node_map_->AddNode(name, node);
     node->set_name(name);
     node->set_op("Const");
-    node->set_device(device);
     AttrValue attr_data_type;
     attr_data_type.set_type(dtype);
     node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -151,6 +160,13 @@ class GraphProcessor {
     tensor.scalar<int>()() = value;
     tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
     node->mutable_attr()->insert({"value", attr_tensor});
+    string device_name;
+    if (device.empty()) {
+      device_name = virtual_placer_.get_canonical_device_name(*node);
+    } else {
+      device_name = device;
+    }
+    node->set_device(device_name);
     return node;
   }
 
@@ -159,7 +175,6 @@ class GraphProcessor {
     node_map_->AddNode(name, node);
     node->set_name(name);
     node->set_op("Const");
-    node->set_device(device);
     AttrValue attr_data_type;
     attr_data_type.set_type(DT_INT32);
     node->mutable_attr()->insert({"dtype", attr_data_type});
@@ -172,26 +187,37 @@ class GraphProcessor {
     }
     tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
     node->mutable_attr()->insert({"value", attr_tensor});
+    string device_name;
+    if (device.empty()) {
+      device_name = virtual_placer_.get_canonical_device_name(*node);
+    } else {
+      device_name = device;
+    }
+    node->set_device(device_name);
     return node;
   }
 
+  const VirtualPlacer& virtual_placer_;
+  const std::unordered_set<string>& nodes_to_preserve_;
   GraphDef* graph_;
   NodeMap* node_map_;
-  const std::unordered_set<string>& nodes_to_preserve_;
 };
 
 struct OptimizeContext {
   OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                  const VirtualPlacer& virtual_placer,
                   const std::unordered_set<string>& nodes_to_preserve,
                   bool is_in_frame)
       : graph(graph),
         node(node),
         node_map(node_map),
+        virtual_placer(virtual_placer),
         nodes_to_preserve(nodes_to_preserve),
         is_in_frame(is_in_frame) {}
   GraphDef* graph;
   NodeDef* node;
   NodeMap* node_map;
+  const VirtualPlacer& virtual_placer;
   const std::unordered_set<string>& nodes_to_preserve;
   bool is_in_frame;
 };
@@ -199,8 +225,8 @@ struct OptimizeContext {
 class NodeProcessor : public GraphProcessor {
  public:
   explicit NodeProcessor(const OptimizeContext& opt_cxt)
-      : GraphProcessor(opt_cxt.graph, opt_cxt.node_map,
-                       opt_cxt.nodes_to_preserve),
+      : GraphProcessor(opt_cxt.virtual_placer, opt_cxt.nodes_to_preserve,
+                       opt_cxt.graph, opt_cxt.node_map),
         node_(opt_cxt.node),
         is_in_frame_(opt_cxt.is_in_frame) {}
   virtual ~NodeProcessor() {}
@@ -257,7 +283,25 @@ class NodeProcessor : public GraphProcessor {
   }
 
   virtual bool ShouldProcess() const {
-    return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs();
+    return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs() &&
+           IsOnGPU();
+  }
+
+  virtual bool IsOnGPU() const {
+    string device_name;
+    if (node_->device().empty()) {
+      device_name = virtual_placer_.get_canonical_device_name(*node_);
+    } else {
+      device_name = node_->device();
+    }
+    string device;
+    string not_used;
+    if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
+        (StringPiece(str_util::Lowercase(device)))
+            .contains(str_util::Lowercase(DEVICE_GPU))) {
+      return true;
+    }
+    return false;
   }
 
   void UpdateAttrDataFormat() {
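IsOnGPU resolves an empty device through the VirtualPlacer, then splits the canonical device name and does a case-insensitive substring check for the GPU device type, so "gpu:0", "GPU:0", and "device:GPU:0" all match. A rough Python transcription of the check (illustrative; DeviceNameUtils is simplified to a path split):

    def is_on_gpu(canonical_device_name):
        # Keep only the trailing device component, e.g. "device:GPU:0".
        last = canonical_device_name.split("/")[-1]
        return "gpu" in last.lower()

    assert is_on_gpu("/job:w/replica:0/task:0/device:GPU:0")
    assert is_on_gpu("/job:w/replica:0/task:0/gpu:0")
    assert not is_on_gpu("/job:w/replica:0/task:0/device:CPU:0")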
@@ -536,6 +580,9 @@ class BiasAddGradProcessor : public NodeProcessor {
     if (MustPreserve()) {
       return false;
     }
+    if (!IsOnGPU()) {
+      return false;
+    }
     auto input = node_map_->GetNode(node_->input(0));
     if (input) {
       if ((IsNHWC() && IsDimsFour(*input)) || IsNodeNCHWToNHWC(input->name())) {
@@ -556,7 +603,7 @@ class Conv2DProcessor : public NodeProcessor {
  protected:
   bool ShouldProcess() const override {
     return !MustPreserve() && IsNHWC() && IsDimsFour(*node_) && HasOutputs() &&
-           (!IsGemmUsed() || no_gemm_);
+           (!IsGemmUsed() || no_gemm_) && IsOnGPU();
   }
 
   TensorShapeProto GetShape(const string& input_name) const {
@@ -667,10 +714,24 @@ class FusedBatchNormGradProcessor : public NodeProcessor {
       : NodeProcessor(opt_cxt) {}
 
  protected:
+  bool ShouldProcess() const override {
+    return NodeProcessor::ShouldProcess() && IsTraining();
+  }
+
   std::vector<int> GetInputPos() const override {
     std::vector<int> input_pos = {0, 1};
     return input_pos;
   }
+
+ private:
+  bool IsTraining() const {
+    if (node_->attr().find("is_training") != node_->attr().end()) {
+      if (node_->attr().at("is_training").b()) {
+        return true;
+      }
+    }
+    return false;
+  }
 };
 
 class MaxPoolGradProcessor : public NodeProcessor {
@@ -693,7 +754,7 @@ class AgnosticNodeProcessor : public NodeProcessor {
  protected:
   bool ShouldProcess() const override {
     return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
-           IsNodeAfterNCHWToNHWC();
+           IsNodeAfterNCHWToNHWC() && IsOnGPU();
   }
 
   bool IsNodeAfterNCHWToNHWC() const {
@@ -746,7 +807,8 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
     return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
            IsNodeAfterNCHWToNHWC() &&
            (Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
-            Is4DOperateWithVector());
+            Is4DOperateWithVector()) &&
+           IsOnGPU();
   }
 
   std::vector<int> GetInputPos() const override {
@@ -855,7 +917,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
  protected:
   bool ShouldProcess() const override {
     return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
-           IsNodeAfterNCHWToNHWC() && IsAlongDimC();
+           IsNodeAfterNCHWToNHWC() && IsAlongDimC() && IsOnGPU();
   }
 
   std::vector<int> GetInputPos() const override {
@@ -920,7 +982,7 @@ class PadProcessor : public AgnosticNodeProcessor {
  protected:
   bool ShouldProcess() const override {
     return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
-           IsNodeAfterNCHWToNHWC() && PaddingSupported();
+           IsNodeAfterNCHWToNHWC() && PaddingSupported() && IsOnGPU();
   }
   Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
 
@@ -1132,7 +1194,8 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
  protected:
   bool ShouldProcess() const override {
     return !MustPreserve() && IsDimsN(*node_, 2) && HasOutputs() &&
-           IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW();
+           IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() &&
+           IsOnGPU();
   }
 
   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
@@ -1183,7 +1246,7 @@ class SumProcessor : public AgnosticNodeProcessor {
     auto input0 = node_map_->GetNode(node_->input(0));
     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
            (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
-           IsAlongDimNHW();
+           IsAlongDimNHW() && IsOnGPU();
   }
 
   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
@@ -1243,42 +1306,41 @@ class SumProcessor : public AgnosticNodeProcessor {
 class DataLayoutOptimizer : GraphProcessor {
  public:
   explicit DataLayoutOptimizer(
-      LayoutOptimizer::TuningConfig config,
-      const std::unordered_set<string>& nodes_to_preserve,
-      const string& default_device, GraphDef* graph, NodeMap* node_map)
-      : GraphProcessor(graph, node_map, nodes_to_preserve),
-        config_(config),
-        default_device_(default_device) {}
+      const VirtualPlacer& virtual_placer,
+      const LayoutOptimizer::TuningConfig& config,
+      const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
+      NodeMap* node_map)
+      : GraphProcessor(virtual_placer, nodes_to_preserve, graph, node_map),
        config_(config) {}
 
   Status Optimize() {
-    LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
+    VLOG(1) << "Number of nodes for original graph: " << graph_->node_size();
     TF_RETURN_IF_ERROR(Expand());
-    LOG(INFO) << "Number of nodes after Expand: " << graph_->node_size();
+    VLOG(1) << "Number of nodes after Expand: " << graph_->node_size();
     TF_RETURN_IF_ERROR(Collapse());
-    LOG(INFO) << "Number of nodes after Collapse: " << graph_->node_size();
+    VLOG(1) << "Number of nodes after Collapse: " << graph_->node_size();
     return Status::OK();
   }
 
  private:
   NodeDef* AddNodePermNHWCToNCHW() {
-    return AddNodePermConst(kPermNHWCToNCHW, default_device_, {0, 3, 1, 2});
+    return AddNodePermConst(kPermNHWCToNCHW, "", {0, 3, 1, 2});
   }
 
   NodeDef* AddNodePermNCHWToNHWC() {
-    return AddNodePermConst(kPermNCHWToNHWC, default_device_, {0, 2, 3, 1});
+    return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1});
   }
 
   NodeDef* AddNodeConcatConst() {
-    return AddNodeConstScalar(kConcatConst, default_device_, DT_INT32, 1);
+    return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1);
   }
 
   NodeDef* AddNodeGatherAxisConst() {
-    return AddNodeConstScalar(kGatherAxisConst, default_device_, DT_INT32, 0);
+    return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0);
   }
 
   NodeDef* AddNodeReductionConst() {
-    return GraphProcessor::AddNodeReductionConst(kReductionConst,
-                                                 default_device_);
+    return GraphProcessor::AddNodeReductionConst(kReductionConst, "");
   }
 
   // Expand all nodes which are in NHWC but support NCHW, or are layout agnostic.
@@ -1295,8 +1357,8 @@ class DataLayoutOptimizer : GraphProcessor {
             ops_format_supported.end()) {
           auto node = graph_->mutable_node(i);
           bool is_in_frame = !frames[node].empty();
-          OptimizeContext opt_cxt(graph_, node, node_map_, nodes_to_preserve_,
-                                  is_in_frame);
+          OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
+                                  nodes_to_preserve_, is_in_frame);
           std::unique_ptr<NodeProcessor> node_processor;
           if (IsAvgPoolGrad(*node)) {
             node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
@@ -1343,8 +1405,8 @@ class DataLayoutOptimizer : GraphProcessor {
             ops_format_agnostic.end()) {
           auto node = graph_->mutable_node(i);
           bool is_in_frame = !frames[node].empty();
-          OptimizeContext opt_cxt(graph_, node, node_map_, nodes_to_preserve_,
-                                  is_in_frame);
+          OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
+                                  nodes_to_preserve_, is_in_frame);
           std::unique_ptr<NodeProcessor> node_processor;
           if (IsAddN(*node)) {
             node_processor.reset(new AddNProcessor(opt_cxt));
@@ -1419,8 +1481,7 @@ class DataLayoutOptimizer : GraphProcessor {
     return Status::OK();
   }
 
-  LayoutOptimizer::TuningConfig config_;
-  string default_device_;
+  const LayoutOptimizer::TuningConfig& config_;
 };
 
 int GetNumTranspose(const GraphDef& graph) {
@@ -1430,7 +1491,7 @@ int GetNumTranspose(const GraphDef& graph) {
       number++;
     }
   }
-  LOG(INFO) << "Number of Transpose nodes: " << number;
+  VLOG(1) << "Number of Transpose nodes: " << number;
   return number;
 }
 
@@ -1455,7 +1516,6 @@ int GetNumGPUs(const Cluster& cluster) {
 
 Status LayoutOptimizer::Tune(const GrapplerItem& item,
                              const GraphProperties& graph_properties,
-                             const string& default_device,
                              const TuningConfig& config, GraphDef* output) {
   auto status = graph_properties.AnnotateOutputShapes(output);
   if (!status.ok()) {
@@ -1463,8 +1523,8 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item,
     return status;
   }
   NodeMap node_map(output);
-  DataLayoutOptimizer layout_optimizer(config, nodes_to_preserve_,
-                                       default_device, output, &node_map);
+  DataLayoutOptimizer layout_optimizer(*virtual_placer_, config,
+                                       nodes_to_preserve_, output, &node_map);
   status = layout_optimizer.Optimize();
   return status;
 }
@@ -1477,6 +1537,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     return Status::OK();
   }
 
+  virtual_placer_.reset(new VirtualPlacer(cluster));
   nodes_to_preserve_ = item.NodesToPreserve();
   GraphProperties graph_properties(item);
   auto status = graph_properties.InferStatically();
@@ -1487,20 +1548,13 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
 
   TuningConfig config;
   config.no_gemm = false;
-  string default_device = "/job:localhost/replica:0/task:0/cpu:0";
-  if (cluster) {
-    if (!cluster->GetDevices().empty()) {
-      default_device = cluster->GetDevices().begin()->first;
-    }
-  }
-
-  status = Tune(item, graph_properties, default_device, config, output);
-
+  status = Tune(item, graph_properties, config, output);
   // This is based on an empirical observation that if the introduced Transpose
   // nodes are more than 30, not using the GEMM implementation would result in
   // better performance.
   if (status.ok() && GetNumTranspose(*output) > 30) {
     config.no_gemm = true;
-    status = Tune(item, graph_properties, default_device, config, output);
+    status = Tune(item, graph_properties, config, output);
   }
 
   if (!status.ok()) {
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
 
 #include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/costs/virtual_placer.h"
 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 
 namespace tensorflow {
@@ -47,10 +48,10 @@ class LayoutOptimizer : public GraphOptimizer {
                 const GraphDef& optimize_output, double result) override;
 
  private:
+  std::unique_ptr<VirtualPlacer> virtual_placer_;
   std::unordered_set<string> nodes_to_preserve_;
   Status Tune(const GrapplerItem& item, const GraphProperties& graph_properties,
-              const string& default_device, const TuningConfig& config,
-              GraphDef* output);
+              const TuningConfig& config, GraphDef* output);
 };
 
 } // end namespace grappler
@ -39,6 +39,11 @@ class LayoutOptimizerTest : public ::testing::Test {
|
|||||||
|
|
||||||
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
|
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
|
||||||
const string& padding) {
|
const string& padding) {
|
||||||
|
return SimpleConv2D(s, input_size, filter_size, padding, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
|
||||||
|
const string& padding, const string& device) {
|
||||||
int batch_size = 128;
|
int batch_size = 128;
|
||||||
int input_height = input_size;
|
int input_height = input_size;
|
||||||
int input_width = input_size;
|
int input_width = input_size;
|
||||||
@ -59,8 +64,8 @@ class LayoutOptimizerTest : public ::testing::Test {
|
|||||||
Output filter =
|
Output filter =
|
||||||
ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
|
ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
|
||||||
|
|
||||||
Output conv = ops::Conv2D(s->WithOpName("Conv2D"), input, filter,
|
Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
|
||||||
{1, stride, stride, 1}, padding);
|
filter, {1, stride, stride, 1}, padding);
|
||||||
return conv;
|
return conv;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,6 +114,36 @@ class LayoutOptimizerTest : public ::testing::Test {
|
|||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
|
||||||
|
int batch_size = 16;
|
||||||
|
int input_height = 8;
|
||||||
|
int input_width = 8;
|
||||||
|
int input_channels = 3;
|
||||||
|
TensorShape shape({batch_size, input_height, input_width, input_channels});
|
||||||
|
Tensor data(DT_FLOAT, shape);
|
||||||
|
test::FillIota<float>(&data, 1.0f);
|
||||||
|
Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
|
||||||
|
Output y_backprop =
|
||||||
|
ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));
|
||||||
|
|
||||||
|
TensorShape shape_vector({input_channels});
|
||||||
|
Tensor data_vector(DT_FLOAT, shape_vector);
|
||||||
|
test::FillIota<float>(&data_vector, 2.0f);
|
||||||
|
Output scale =
|
||||||
|
ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
|
||||||
|
Output reserve1 =
|
||||||
|
ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
|
||||||
|
Output reserve2 =
|
||||||
|
ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));
|
||||||
|
|
||||||
|
ops::FusedBatchNormGrad::Attrs attrs;
|
||||||
|
attrs.is_training_ = is_training;
|
||||||
|
auto output =
|
||||||
|
ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
|
||||||
|
x, scale, reserve1, reserve2, attrs);
|
||||||
|
return output.x_backprop;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<VirtualCluster> virtual_cluster_;
|
std::unique_ptr<VirtualCluster> virtual_cluster_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -278,6 +313,92 @@ TEST_F(LayoutOptimizerTest, PreserveFetch) {
   EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
 }

+TEST_F(LayoutOptimizerTest, EmptyDevice) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto conv_node = node_map.GetNode("Conv2D");
+  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
+}
+
+TEST_F(LayoutOptimizerTest, GPUDevice) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv =
+      SimpleConv2D(&s, 3, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto conv_node = node_map.GetNode("Conv2D");
+  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
+}
+
+TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv =
+      SimpleConv2D(&s, 3, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto conv_node = node_map.GetNode("Conv2D");
+  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
+}
+
+TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D(&s, 3, 2, "VALID", "/CPU:0");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto conv_node = node_map.GetNode("Conv2D");
+  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
+}
+
+TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
+  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
+}
+
+TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
+  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
|
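Read together, the six added tests pin down the optimizer's device rule: a Conv2D placed on a GPU, or left with an empty device, is rewritten to NCHW, while any explicit CPU placement keeps NHWC. A tiny Python sketch of that expectation (an illustrative predicate only, not Grappler's actual implementation):

```python
# Illustrative only: the data_format these tests expect from the layout
# optimizer, as a function of the Conv2D node's device string. An empty
# device behaves like a GPU placement; explicit CPU placements keep NHWC.
def expected_data_format(device):
    if device == "" or "gpu" in device.lower():
        return "NCHW"
    return "NHWC"

assert expected_data_format("") == "NCHW"                                       # EmptyDevice
assert expected_data_format("/job:w/replica:0/task:0/device:gpu:0") == "NCHW"   # GPUDevice
assert expected_data_format("/job:w/replica:0/task:0/device:cpu:0") == "NHWC"   # CPUDeviceLowercase
assert expected_data_format("/CPU:0") == "NHWC"                                 # CPUDeviceUppercase
```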
@ -6059,6 +6059,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
     ],
 )
@ -79,16 +79,20 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
-      if (!input_impl_) {
-        *end_of_sequence = true;
-        return Status::OK();
-      }
-      Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
-      while (!s.ok()) {
-        out_tensors->clear();
-        s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+      {
+        tf_shared_lock l(mu_);
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+        Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+        while (!s.ok()) {
+          out_tensors->clear();
+          s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+        }
       }
       if (*end_of_sequence) {
+        mutex_lock l(mu_);
         input_impl_.reset();
       }
       return Status::OK();

@ -96,6 +100,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {

    protected:
     Status SaveInternal(IteratorStateWriter* writer) override {
+      mutex_lock l(mu_);
       if (input_impl_)
         TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
       else

@ -106,6 +111,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {

     Status RestoreInternal(OpKernelContext* ctx,
                            IteratorStateReader* reader) override {
+      mutex_lock l(mu_);
       if (reader->Contains(full_name("input_impls_empty")))
         input_impl_.reset();
       else

@ -114,7 +120,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
     }

    private:
-    std::unique_ptr<IteratorBase> input_impl_;
+    mutex mu_;
+    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
   };

   const DatasetBase* const input_;
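The locking added here lets concurrent GetNext calls read `input_impl_` under a shared lock, while the reset to null takes an exclusive lock. A rough Python analogue of the guarded ignore-errors iterator (Python's stdlib has no shared lock, so one `threading.Lock` stands in for both `tf_shared_lock` and `mutex_lock`; all names are illustrative):

```python
import threading

class IgnoreErrorsIterator:
    def __init__(self, inner):
        self._mu = threading.Lock()     # plays the role of mu_
        self._inner = inner             # plays the role of input_impl_

    def get_next(self):
        with self._mu:
            if self._inner is None:     # already exhausted
                raise StopIteration
            while True:
                try:
                    return next(self._inner)
                except StopIteration:
                    self._inner = None  # kernel: input_impl_.reset()
                    raise
                except Exception:
                    continue            # swallow the error, try the next element

it = IgnoreErrorsIterator(iter([1, 2, 3]))
assert [it.get_next(), it.get_next(), it.get_next()] == [1, 2, 3]
```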
@ -14,9 +14,10 @@ limitations under the License.
 ==============================================================================*/
 #include <deque>

-#include "tensorflow/core/kernels/dataset.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/dataset.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"

 namespace tensorflow {

@ -39,14 +40,14 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
     OP_REQUIRES(ctx, buffer_size > 0,
                 errors::InvalidArgument("buffer_size must be > 0"));

-    *output = new Dataset(input, buffer_size);
+    *output = new Dataset(ctx, input, buffer_size);
   }

  private:
-  class Dataset : public DatasetBase {
+  class Dataset : public GraphDatasetBase {
    public:
-    Dataset(const DatasetBase* input, int64 buffer_size)
-        : input_(input), buffer_size_(buffer_size) {
+    Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size)
+        : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
       input_->Ref();
     }

@ -67,6 +68,18 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {

     string DebugString() override { return "PrefetchDatasetOp::Dataset"; }

+   protected:
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+      Node* buffer_size = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
+      TF_RETURN_IF_ERROR(
+          b->AddDataset(this, {input_graph_node, buffer_size}, output));
+      return Status::OK();
+    }
+
    private:
     class Iterator : public DatasetIterator<Dataset> {
      public:

@ -121,7 +134,10 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {

         // Wake the prefetch thread, in case it has been waiting
         // for space in the buffer.
-        cond_var_.notify_one();
+        // Also wake up threads from other calls to GetNext.
+        // TODO(mrry): Consider using different condition variables
+        // for GetNext and Prefetch.
+        cond_var_.notify_all();
         return s;
       } else if (prefetch_thread_finished_) {
         *end_of_sequence = true;

@ -130,6 +146,69 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
       }
     }

+   protected:
+    Status SaveInternal(IteratorStateWriter* writer) override {
+      // Acquire both locks to ensure that the prefetch thread and
+      // all GetNext threads are blocked.
+      mutex_lock parent_l(parent_mu_);
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
+      for (size_t i = 0; i < buffer_.size(); i++) {
+        auto& buffer_element = buffer_[i];
+        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
+        if (buffer_element.status.ok()) {
+          TF_RETURN_IF_ERROR(writer->WriteScalar(
+              full_name(strings::StrCat("buffer[", i, "].size")),
+              buffer_element.value.size()));
+          for (size_t j = 0; j < buffer_element.value.size(); j++) {
+            TF_RETURN_IF_ERROR(writer->WriteTensor(
+                strings::StrCat("buffer[", i, "][", j, "]"),
+                buffer_element.value[j]));
+          }
+        }
+      }
+      return Status::OK();
+    }
+
+    Status RestoreInternal(OpKernelContext* ctx,
+                           IteratorStateReader* reader) override {
+      mutex_lock parent_l(parent_mu_);
+      mutex_lock l(mu_);
+      buffer_.clear();
+      TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+      size_t buffer_size;
+      {
+        int64 temp;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name("buffer_size"), &temp));
+        buffer_size = static_cast<size_t>(temp);
+      }
+      for (size_t i = 0; i < buffer_size; i++) {
+        buffer_.emplace_back();
+        auto& buffer_element = buffer_.back();
+        TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
+        if (buffer_element.status.ok()) {
+          size_t value_size;
+          {
+            int64 temp;
+            TF_RETURN_IF_ERROR(reader->ReadScalar(
+                full_name(strings::StrCat("buffer[", i, "].size")), &temp));
+            value_size = static_cast<size_t>(temp);
+          }
+          buffer_element.value.reserve(value_size);
+          for (size_t j = 0; j < value_size; j++) {
+            buffer_element.value.emplace_back();
+            TF_RETURN_IF_ERROR(reader->ReadTensor(
+                strings::StrCat("buffer[", i, "][", j, "]"),
+                &buffer_element.value.back()));
+          }
+        }
+      }
+      return Status::OK();
+    }
+
    private:
     // A buffer element comprises a status and (if that status is
     // OK) a vector of tensors, representing an element of the input dataset.

@ -173,6 +252,12 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
       }

       // 2. Read the next element.
+      // Acquire the parent lock since we will be reading an element
+      // from the input iterator. Note that we do not wish to release
+      // this lock till we have added the fetched element to the
+      // `buffer_` else there will be local state that may be missed
+      // by SaveInternal.
+      mutex_lock parent_l(parent_mu_);
       bool end_of_sequence;
       BufferElement buffer_element;
       buffer_element.status = input_impl_->GetNext(

@ -193,8 +278,50 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
       }
     }

+    Status WriteStatus(IteratorStateWriter* writer, size_t index,
+                       const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          CodeKey(index), static_cast<int64>(status.code())));
+      if (!status.ok()) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+                                               status.error_message()));
+      }
+      return Status::OK();
+    }
+
+    Status ReadStatus(IteratorStateReader* reader, size_t index,
+                      Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      int64 code_int;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+      error::Code code = static_cast<error::Code>(code_int);
+
+      if (code != error::Code::OK) {
+        string error_message;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(ErrorMessageKey(index), &error_message));
+        *status = Status(code, error_message);
+      } else {
+        *status = Status::OK();
+      }
+      return Status::OK();
+    }
+
+    string CodeKey(size_t index) {
+      return full_name(strings::StrCat("status[", index, "].code"));
+    }
+
+    string ErrorMessageKey(size_t index) {
+      return full_name(strings::StrCat("status[", index, "].error_message"));
+    }
+
+    // This mutex is used to ensure exclusivity between multiple threads
+    // reading/writing this iterator's local state.
     mutex mu_;
-    const std::unique_ptr<IteratorBase> input_impl_;
+    // This mutex is used to ensure exclusivity between multiple threads
+    // accessing the parent iterator. We keep this separate from `mu_` to
+    // allow prefetching to run in parallel with GetNext calls.
+    mutex parent_mu_ ACQUIRED_BEFORE(mu_);
+    const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
     condition_variable cond_var_;
     std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
     std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
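The new save/restore path serializes each buffered element's Status as an integer code plus, for non-OK codes, an error message, under the keys built by CodeKey/ErrorMessageKey above. A minimal sketch of that round trip, with a plain dict standing in for the IteratorStateWriter/Reader:

```python
# Sketch of the WriteStatus/ReadStatus encoding: an integer error code per
# buffered element, plus an error message only for non-OK codes.
OK = 0  # error::Code::OK

def write_status(state, index, code, message=""):
    state["status[%d].code" % index] = code
    if code != OK:
        state["status[%d].error_message" % index] = message

def read_status(state, index):
    code = state["status[%d].code" % index]
    if code != OK:
        return code, state["status[%d].error_message" % index]
    return OK, ""

state = {}
write_status(state, 0, OK)
write_status(state, 1, 3, "invalid argument")  # 3: an arbitrary non-OK code
assert read_status(state, 0) == (OK, "")
assert read_status(state, 1) == (3, "invalid argument")
```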
@ -4352,7 +4352,10 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
-    deps = [":pywrap_tensorflow_internal"],
+    deps = [
+        ":pywrap_tensorflow_internal",
+        ":tf_cluster",
+    ],
 )

 py_test(
@ -798,13 +798,41 @@ class GradientTape(object):
   grad = g.gradient(y, [x])[0]
   assert grad.numpy() == 6.0
   ```

+  By default, the resources held by a GradientTape are released as soon as
+  the GradientTape.gradient() method is called. However, if one needs to
+  compute multiple gradients over the same computation, they can create a
+  persistent GradientTape. Persistent tapes allow multiple calls to the
+  gradient() method and release resources when the tape object is destructed.
+
+  Example usage:
+
+  ```python
+  with tfe.GradientTape(persistent=True) as g:
+    x = tf.constant(3.0)
+    g.watch(x)
+    y = x * x
+    z = y * y
+  dz_dx = g.gradient(z, [x])[0]
+  assert dz_dx.numpy() == 108.0   # 4*x^3 at x = 3
+  dy_dx = g.gradient(y, [x])[0]
+  assert dy_dx.numpy() == 6.0
+  del g  # Drop the reference to the tape
+  ```
   """

-  def __init__(self):
+  def __init__(self, persistent=False):
+    """Creates a new GradientTape.
+
+    Args:
+      persistent: Boolean controlling whether a persistent gradient tape
+        is created. Must be True or False.
+    """
     self._tape = None
+    self._persistent = persistent

   def __enter__(self):
-    tape.push_new_tape()
+    tape.push_new_tape(persistent=self._persistent)
     return self

   def __exit__(self, typ, value, traceback):

@ -838,12 +866,14 @@ class GradientTape(object):
         than once.
     """
     if self._tape is None:
-      raise RuntimeError("GradientTape.gradient can only be called once, and "
+      raise RuntimeError("GradientTape.gradient can only be called once "
+                         "on non-persistent tapes, and "
                          "only when the context manager has exited.")
     sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
                else x
                for x in sources]
     grad = imperative_grad.imperative_grad(
         _default_vspace, self._tape, [target], sources)
-    self._tape = None
+    if not self._persistent:
+      self._tape = None
     return grad
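A short usage sketch contrasting the two modes this hunk introduces (assumes eager execution has been enabled, e.g. via `tfe.enable_eager_execution()`, with `tfe` being the `tf.contrib.eager` alias the docstring uses):

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

with tfe.GradientTape() as g:          # non-persistent (the default)
  x = tf.constant(3.0)
  g.watch(x)
  y = x * x
g.gradient(y, [x])                     # first call succeeds...
try:
  g.gradient(y, [x])                   # ...a second call raises
except RuntimeError:
  pass

with tfe.GradientTape(persistent=True) as g:
  x = tf.constant(3.0)
  g.watch(x)
  y = x * x
  z = y * y
assert g.gradient(z, [x])[0].numpy() == 108.0  # multiple calls are fine
assert g.gradient(y, [x])[0].numpy() == 6.0
del g                                  # resources released on destruction
```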
@ -314,6 +314,37 @@ class BackpropTest(test.TestCase):
         RuntimeError, 'GradientTape.gradient can only be called once'):
       g.gradient(y, [x])

+  def testPersistentTape(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = x * x
+      z = y * y
+    dz_dx = g.gradient(z, [x])[0]
+    self.assertEqual(dz_dx.numpy(), 4*3*3*3)
+    dy_dx = g.gradient(y, [x])[0]
+    self.assertEqual(dy_dx.numpy(), 2*3)
+    del g
+
+  def testPersistentNestedTape(self):
+    with backprop.GradientTape(persistent=True) as g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = x * x
+      with backprop.GradientTape(persistent=True) as gg:
+        gg.watch(y)
+        z = 2 * y
+      for _ in range(2):
+        inner_grad = gg.gradient(z, [y])[0]
+        self.assertEqual(inner_grad.numpy(), 2.0)
+        y += inner_grad
+      del gg
+    grad = g.gradient(y, [x])[0]
+    self.assertEqual(grad.numpy(), 6.0)
+    grad = g.gradient(z, [x])[0]
+    self.assertEqual(grad.numpy(), 12.0)
+    del g
+
   def testGradientTapeVariable(self):
     v = resource_variable_ops.ResourceVariable(1.0, name='v')
     with backprop.GradientTape() as g:
@ -88,7 +88,8 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);

 // Pushes a new tape into the thread-local stack.
-void TFE_Py_TapeStackPushNew();
+// `persistent` must be a PyBool_Type, i.e. either Py_True or Py_False.
+void TFE_Py_TapeStackPushNew(PyObject* persistent);

 // Pops the tape from the top of the stack and returns it.
 PyObject* TFE_Py_TapeStackPop();
@ -469,7 +469,8 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyObject> {
  public:
-  GradientTape() {}
+  explicit GradientTape(bool persistent)
+      : tensorflow::eager::GradientTape<PyObject, PyObject>(persistent) {}

   void WatchVariable(PyObject* v) {
     watched_variables_.insert(v);

@ -557,11 +558,11 @@ std::vector<TFE_Py_Tape*>* GetTapeStack() {
 }
 #endif

-void TFE_Py_TapeStackPushNew() {
+void TFE_Py_TapeStackPushNew(PyObject* persistent) {
   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return;
   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
-  tape->tape = new GradientTape();
+  tape->tape = new GradientTape(persistent == Py_True);
   GetTapeStack()->push_back(tape);
 }

@ -704,6 +705,7 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
     PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
     list.push_back(FastTensorId(tensor));
     if (PyErr_Occurred()) {
+      Py_DECREF(seq);
       return list;
     }
   }

@ -889,7 +891,6 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
     PyObject* py_result = PyEval_CallObject(
         reinterpret_cast<PyObject*>(backward_function), grads);
     Py_DECREF(grads);
-    Py_DECREF(backward_function);
     if (py_result == nullptr) {
       return tensorflow::errors::Internal("gradient function threw exceptions");
     }

@ -917,6 +918,10 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
     return tensorflow::Status::OK();
   }

+  void ReleaseBackwardFunction(PyObject* backward_function) const final {
+    Py_DECREF(backward_function);
+  }
+
   void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }

  private:
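The net effect of the two refcounting hunks: the caller of a backward function no longer drops the reference after invoking it; the tape decides when each backward function is released, via the new ReleaseBackwardFunction hook. Combined with the persistent flag, release happens eagerly for ordinary tapes and only at destruction for persistent ones. A pure-Python sketch of that ownership contract (illustrative, not the pywrap code):

```python
class Tape(object):
  """Owns its backward functions; releases them eagerly or at destruction."""

  def __init__(self, persistent):
    self._persistent = persistent
    self._backward_fns = []

  def record(self, fn):
    self._backward_fns.append(fn)

  def gradient(self):
    results = [fn() for fn in self._backward_fns]
    if not self._persistent:
      self._release_all()      # mirrors the ReleaseBackwardFunction calls
    return results

  def _release_all(self):
    self._backward_fns = []    # in C++ each entry would get a Py_DECREF

  def __del__(self):
    self._release_all()        # persistent tapes release only here

t = Tape(persistent=True)
t.record(lambda: 6.0)
assert t.gradient() == [6.0]
assert t.gradient() == [6.0]   # still available on a persistent tape
```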
@ -33,9 +33,9 @@ class Tape(object):
     return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)


-def push_new_tape():
+def push_new_tape(persistent=False):
   """Pushes a new tape onto the tape stack."""
-  pywrap_tensorflow.TFE_Py_TapeStackPushNew()
+  pywrap_tensorflow.TFE_Py_TapeStackPushNew(persistent)


 def watch(tensor):
@ -14,6 +14,14 @@ limitations under the License.
 ==============================================================================*/

 %include "tensorflow/python/platform/base.i"
+%include <std_shared_ptr.i>
+%include "item.i"
+
+// Wrap the cluster into an object that swig can manipulate. This ensures it will call the object
+// destructor upon garbage collection instead of leaking memory.
+struct GCluster {
+  std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
+};

 %{
 #include "tensorflow/core/protobuf/device_properties.pb.h"

@ -72,6 +80,7 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
 }

 %{
+#include <memory>
 #include <vector>
 #include "tensorflow/core/grappler/devices.h"
 #include "tensorflow/core/grappler/clusters/single_machine.h"

@ -82,39 +91,56 @@ bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
 #include "tensorflow/core/grappler/costs/utils.h"
 #include "tensorflow/core/protobuf/device_properties.pb.h"

-static tensorflow::grappler::Cluster* TF_NewCluster(
-    bool allow_soft_placement,
-    bool disable_detailed_stats, TF_Status* out_status) {
-  int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
-  int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();;
+// Provide the implementation of the GCluster struct here.
+struct GCluster {
+  GCluster() {}
+  GCluster(tensorflow::grappler::Cluster* cluster) : cluster_(cluster) {}
+
+  tensorflow::grappler::Cluster* operator->() const {
+    return cluster_.get();
+  }
+  tensorflow::grappler::Cluster* get() const {
+    return cluster_.get();
+  }
+  bool is_none() const {
+    return cluster_.get() == nullptr;
+  }
+
+  std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
+};
+
+static GCluster TF_NewCluster(bool allow_soft_placement,
+                              bool disable_detailed_stats, TF_Status* out_status) {
+  int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
+  int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
   int timeout_s = 60 * 10;
-  tensorflow::grappler::Cluster* cluster =
+  tensorflow::grappler::Cluster* cluster_ =
       new tensorflow::grappler::SingleMachine(
           timeout_s, num_cpu_cores, num_gpus);
-  cluster->DisableDetailedStats(disable_detailed_stats);
-  cluster->AllowSoftPlacement(allow_soft_placement);
-  tensorflow::Status status = cluster->Provision();
+  cluster_->DisableDetailedStats(disable_detailed_stats);
+  cluster_->AllowSoftPlacement(allow_soft_placement);
+  tensorflow::Status status = cluster_->Provision();
   tensorflow::Set_TF_Status_from_Status(out_status, status);
-  return cluster;
+  return GCluster(cluster_);
 }

-static tensorflow::grappler::Cluster* TF_NewVirtualCluster(
+static GCluster TF_NewVirtualCluster(
     const std::vector<tensorflow::NamedDevice>& named_devices,
     TF_Status* out_status) {
   std::unordered_map<string, tensorflow::DeviceProperties> devices;
   for (const auto& named_device : named_devices) {
     devices[named_device.name()]= named_device.properties();
   }
-  tensorflow::grappler::Cluster* cluster =
+  tensorflow::grappler::Cluster* cluster_ =
       new tensorflow::grappler::VirtualCluster(devices);
-  tensorflow::Status status = cluster->Provision();
+  tensorflow::Status status = cluster_->Provision();
   tensorflow::Set_TF_Status_from_Status(out_status, status);
-  return cluster;
+  return GCluster(cluster_);
 }

-static void TF_DeleteCluster(tensorflow::grappler::Cluster* cluster) {
+static void TF_ShutdownCluster(GCluster cluster) {
   cluster->Shutdown();
-  delete cluster;
 }

 tensorflow::Status _GetOpPerformanceDataAndRunTime(

@ -136,8 +162,9 @@ tensorflow::Status _GetOpPerformanceDataAndRunTime(
   return tensorflow::Status::OK();
 }

-static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) {
+static PyObject* TF_ListDevices(GCluster cluster) {
   const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
+  PyGILState_STATE gstate = PyGILState_Ensure();
   PyObject* result = PyList_New(devices.size());
   int i = 0;
   for (auto& dev : devices) {

@ -150,17 +177,18 @@ static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster) {
     PyList_SetItem(result, i, dev_obj);
     ++i;
   }
+  PyGILState_Release(gstate);
   return result;
 }

 static PyObject* TF_MeasureCosts(
-    const tensorflow::grappler::GrapplerItem* item,
-    tensorflow::grappler::Cluster* cluster,
+    GItem item,
+    GCluster cluster,
     bool generate_timeline, TF_Status* out_status) {
   tensorflow::OpPerformanceList op_performance_data;
   tensorflow::StepStats step_stats;

-  tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster, 10, 0);
+  tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), 10, 0);

   tensorflow::grappler::Costs costs;
   tensorflow::Status status = _GetOpPerformanceDataAndRunTime(

@ -184,6 +212,7 @@ static PyObject* TF_MeasureCosts(
   if (!status.ok()) {
     Py_RETURN_NONE;
   }
+  PyGILState_STATE gstate = PyGILState_Ensure();
   PyObject* op_perf_objs = PyList_New(
       op_performance_data.op_performance_size());
   for (int i = 0; i < op_performance_data.op_performance_size(); i++) {

@ -211,17 +240,19 @@ static PyObject* TF_MeasureCosts(
     status = tensorflow::Status(tensorflow::error::Code::INTERNAL,
                                 "Error setting return tuples.");
     tensorflow::Set_TF_Status_from_Status(out_status, status);
-    Py_RETURN_NONE;
+    Py_INCREF(Py_None);
+    ret = Py_None;
   }
+  PyGILState_Release(gstate);
   return ret;
 }


 static PyObject* TF_DeterminePeakMemoryUsage(
-    const tensorflow::grappler::GrapplerItem* item,
-    tensorflow::grappler::Cluster* cluster,
+    GItem item,
+    GCluster cluster,
     TF_Status* out_status) {
-  if (!item || !cluster) {
+  if (item.is_none() || cluster.is_none()) {
     tensorflow::Status status(tensorflow::error::Code::INTERNAL,
         "You need both a cluster and an item to determine peak memory usage");
     tensorflow::Set_TF_Status_from_Status(out_status, status);

@ -231,7 +262,7 @@ static PyObject* TF_DeterminePeakMemoryUsage(

   tensorflow::Status status;
   if (cluster->DetailedStatsEnabled()) {
-    status = memory.InferDynamically(cluster);
+    status = memory.InferDynamically(cluster.get());
   } else {
     status = memory.InferStatically(cluster->GetDevices());
   }

@ -240,6 +271,7 @@ static PyObject* TF_DeterminePeakMemoryUsage(
     Py_RETURN_NONE;
   }

+  PyGILState_STATE gstate = PyGILState_Ensure();
   PyObject* result = PyDict_New();
   for (const auto& device : cluster->GetDevices()) {
     const tensorflow::grappler::GraphMemory::MemoryUsage& usage =

@ -261,24 +293,24 @@ static PyObject* TF_DeterminePeakMemoryUsage(
     PyTuple_SetItem(ret, 1, per_device);
     PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret);
   }
+  PyGILState_Release(gstate);
   return result;
 }

 %}

 // Wrap these functions.
-static tensorflow::grappler::Cluster* TF_NewCluster(
+static GCluster TF_NewCluster(
     bool allow_soft_placement, bool disable_detailed_stats, TF_Status* out_status);
-static tensorflow::grappler::Cluster* TF_NewVirtualCluster(
+static GCluster TF_NewVirtualCluster(
     const std::vector<tensorflow::NamedDevice>& named_devices,
     TF_Status* out_status);
-static void TF_DeleteCluster(tensorflow::grappler::Cluster* cluster);
+static void TF_ShutdownCluster(GCluster cluster);
-static PyObject* TF_ListDevices(tensorflow::grappler::Cluster* cluster);
+static PyObject* TF_ListDevices(GCluster cluster);
 static PyObject* TF_MeasureCosts(
-    const tensorflow::grappler::GrapplerItem* item, tensorflow::grappler::Cluster* cluster,
+    GItem item, GCluster cluster,
     bool generate_timeline, TF_Status* out_status);
 static PyObject* TF_DeterminePeakMemoryUsage(
-    const tensorflow::grappler::GrapplerItem* item, tensorflow::grappler::Cluster* cluster,
+    GItem item, GCluster cluster,
     TF_Status* out_status);
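The GCluster rework replaces a raw `Cluster*` with a value type holding a `shared_ptr`, so the SWIG wrapper frees the cluster when Python garbage-collects it; `TF_ShutdownCluster` now only shuts the cluster down and leaves deletion to the `shared_ptr`. A Python analogue of that ownership pattern (all names hypothetical):

```python
class OwnedCluster(object):
  """Value-type handle: dropping the last reference frees the resource."""

  def __init__(self, raw=None):
    self._raw = raw              # plays the role of the shared_ptr payload

  def get(self):
    return self._raw             # like GCluster::get()

  def is_none(self):
    return self._raw is None     # like GCluster::is_none()

  def __del__(self):
    if self._raw is not None:
      self._raw.close()          # stand-in for deleting the Cluster

class FakeCluster(object):
  def __init__(self):
    self.closed = False
  def close(self):
    self.closed = True

raw = FakeCluster()
handle = OwnedCluster(raw)
del handle                       # wrapper collected -> resource freed
assert raw.closed
```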
@ -46,6 +46,7 @@ class Cluster(object):
         the local machine.
     """
     self._tf_cluster = None
+    self._generate_timeline = not disable_timeline
     with errors.raise_exception_on_not_ok_status() as status:
       if devices is None:
         self._tf_cluster = tf_cluster.TF_NewCluster(

@ -54,11 +55,10 @@ class Cluster(object):
         devices_serialized = [device.SerializeToString() for device in devices]
         self._tf_cluster = tf_cluster.TF_NewVirtualCluster(
             devices_serialized, status)
-    self._generate_timeline = not disable_timeline

   def __del__(self):
     if self._tf_cluster is not None:
-      tf_cluster.TF_DeleteCluster(self._tf_cluster)
+      tf_cluster.TF_ShutdownCluster(self._tf_cluster)

   @property
   def tf_cluster(self):
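Hoisting `self._generate_timeline` above the `with` block is most likely defensive ordering: if provisioning raises partway through `__init__`, `__del__` can still run, so attributes consulted later should exist before anything that can fail. A minimal illustration of the hazard, with hypothetical names:

```python
class Fragile(object):
  def __init__(self, fail=False):
    self._handle = None          # set eagerly, like _tf_cluster above
    if fail:
      raise RuntimeError("provisioning failed")
    self._handle = "native-handle"

  def __del__(self):
    # Safe only because _handle was assigned before the raise.
    if self._handle is not None:
      pass  # shut the native handle down here

try:
  Fragile(fail=True)
except RuntimeError:
  pass  # __del__ still runs without an AttributeError
```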
@ -15,6 +15,7 @@ limitations under the License.

 %include "tensorflow/python/lib/core/strings.i"
 %include "tensorflow/python/platform/base.i"
+%include "cluster.i"

 %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
   char* c_string;

@ -42,8 +43,8 @@ limitations under the License.
 %}

 %{
-string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool
-per_node_report, tensorflow::grappler::Cluster* cluster) {
+string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
+                          GCluster cluster) {
   tensorflow::grappler::ItemConfig cfg;
   cfg.apply_optimizations = false;
   std::unique_ptr<tensorflow::grappler::GrapplerItem> item =

@ -53,7 +54,7 @@ per_node_report, tensorflow::grappler::Cluster* cluster) {
   }

   string suffix;
-  tensorflow::grappler::CostAnalyzer analyzer(*item, cluster, suffix);
+  tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix);

   std::stringstream os;
   analyzer.GenerateReport(os, per_node_report);

@ -62,5 +63,5 @@ per_node_report, tensorflow::grappler::Cluster* cluster) {

 %}

-string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool
-per_node_report, tensorflow::grappler::Cluster* cluster);
+string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report,
+                          GCluster cluster);
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+%include <std_shared_ptr.i>
 %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
   char* c_string;
   Py_ssize_t py_size;

@ -30,7 +31,12 @@ limitations under the License.
   $1 = &temp;
 }

-%newobject TF_NewItem;
+// Wrap the item into an object that swig can manipulate. This ensures it will call the object
+// destructor upon garbage collection instead of leaking memory.
+struct GItem {
+  std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
+};
+

 %{
 #include <unordered_set>

@ -42,8 +48,26 @@ limitations under the License.
 #include "tensorflow/core/lib/core/error_codes.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"

-static tensorflow::grappler::GrapplerItem* TF_NewItem(
+// Provide the implementation of the GItem struct here.
+struct GItem {
+  GItem() {}
+  GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {}
+
+  tensorflow::grappler::GrapplerItem* operator->() const {
+    return item_.get();
+  }
+  const tensorflow::grappler::GrapplerItem& operator*() const {
+    return *item_.get();
+  }
+  bool is_none() const {
+    return item_.get() == nullptr;
+  }
+  std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
+};
+
+static GItem TF_NewItem(
     const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
     bool ignore_user_placement, TF_Status* out_status) {
   if (meta_graph.collection_def().count("train_op") == 0) {

@ -65,11 +89,11 @@ static tensorflow::grappler::GrapplerItem* TF_NewItem(
     return nullptr;
   }
   tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
-  return item.release();
+  return GItem(item.release());
 }

-static std::vector<string> TF_IdentifyImportantOps(const tensorflow::grappler::GrapplerItem* item) {
-  if (!item) {
+static std::vector<string> TF_IdentifyImportantOps(GItem item) {
+  if (item.is_none()) {
     return {};
   }

@ -91,8 +115,8 @@ static std::vector<string> TF_IdentifyImportantOps(const tensorflow::grappler::G
   return ops;
 }

-static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* item) {
-  if (!item) {
+static PyObject* TF_GetOpProperties(GItem item) {
+  if (item.is_none()) {
     Py_RETURN_NONE;
   }
   tensorflow::grappler::GraphProperties properties(*item);

@ -101,6 +125,7 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it
     Py_RETURN_NONE;
   }

+  PyGILState_STATE gstate = PyGILState_Ensure();
   PyObject* props = PyDict_New();
   for (const auto& node : item->graph.node()) {
     const string& node_name = node.name();

@ -115,8 +140,8 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it
       PyList_SetItem(prop, i, output_prop);
     }
     CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop));
   }
+  PyGILState_Release(gstate);
   return props;
 }

@ -124,8 +149,8 @@ static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* it


 // Wrap these functions.
-static tensorflow::grappler::GrapplerItem* TF_NewItem(
+static GItem TF_NewItem(
     const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
     bool ignore_user_placement, TF_Status* out_status);
-static std::vector<string> TF_IdentifyImportantOps(const tensorflow::grappler::GrapplerItem* item);
+static std::vector<string> TF_IdentifyImportantOps(GItem item);
-static PyObject* TF_GetOpProperties(const tensorflow::grappler::GrapplerItem* item);
+static PyObject* TF_GetOpProperties(GItem item);
@ -15,6 +15,7 @@ limitations under the License.


 %include "tensorflow/python/platform/base.i"
+%include "cluster.i"

 %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
   char* c_string;

@ -92,7 +93,7 @@ void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* dev
 }

 PyObject* TF_OptimizeGraph(
-    tensorflow::grappler::Cluster* cluster,
+    GCluster cluster,
     const tensorflow::RewriterConfig& rewriter_config,
     const tensorflow::MetaGraphDef& metagraph,
     bool verbose, const string& graph_id, TF_Status* out_status) {

@ -102,17 +103,10 @@ PyObject* TF_OptimizeGraph(
   std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
       tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);

-  std::unique_ptr<tensorflow::grappler::VirtualCluster> virtual_cluster;
-  if (cluster == nullptr) {
-    std::unordered_map<string, tensorflow::DeviceProperties> device_map;
-    DetectDevices(&device_map);
-    virtual_cluster.reset(new tensorflow::grappler::VirtualCluster(device_map));
-    cluster = virtual_cluster.get();
-  }
   tensorflow::DeviceBase* cpu_device = nullptr;
   tensorflow::GraphDef out_graph;
   tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config);
-  tensorflow::Status status = optimizer.Optimize(cluster, *grappler_item, &out_graph);
+  tensorflow::Status status = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
   if (verbose) {
     optimizer.PrintResult();
   }

@ -127,7 +121,7 @@ PyObject* TF_OptimizeGraph(

 // Wrap this function
 PyObject* TF_OptimizeGraph(
-    tensorflow::grappler::Cluster* cluster,
+    GCluster cluster,
     const tensorflow::RewriterConfig& rewriter_config,
     const tensorflow::MetaGraphDef& metagraph, bool verbose,
     const string& graph_id, TF_Status* out_status);
@ -21,6 +21,7 @@ from __future__ import print_function
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python import pywrap_tensorflow as tf_opt
 from tensorflow.python.framework import errors
+from tensorflow.python.grappler import cluster as gcluster


 def OptimizeGraph(rewriter_config,

@ -30,8 +31,9 @@ def OptimizeGraph(rewriter_config,
                   cluster=None):
   """Optimize the provided metagraph."""
   with errors.raise_exception_on_not_ok_status() as status:
-    ret_from_swig = tf_opt.TF_OptimizeGraph(None if cluster is None else
-                                            cluster.tf_cluster,
+    if cluster is None:
+      cluster = gcluster.Cluster()
+    ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
                                             rewriter_config.SerializeToString(),
                                             metagraph.SerializeToString(),
                                             verbose, graph_id, status)
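On the Python side the hunk replaces pass-`None`-through with a lazily constructed default `Cluster()`, matching the C++ side dropping its own device-detection fallback. The lazy-default pattern in isolation (hypothetical names, not the TF API):

```python
def optimize_graph(metagraph, cluster=None, make_default=lambda: "default-cluster"):
  # Was: forward None and let the C++ side detect devices itself.
  # Now: materialize a real cluster up front when the caller omits one.
  if cluster is None:
    cluster = make_default()
  return cluster, metagraph

assert optimize_graph("mg")[0] == "default-cluster"
assert optimize_graph("mg", cluster="mine")[0] == "mine"
```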
@ -565,33 +565,34 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
   if not _like_rnncell(cell):
     raise TypeError("cell must be an instance of RNNCell")

-  # By default, time_major==False and inputs are batch-major: shaped
-  #   [batch, time, depth]
-  # For internal calculations, we transpose to [time, batch, depth]
-  flat_input = nest.flatten(inputs)
-
-  if not time_major:
-    # (B,T,D) => (T,B,D)
-    flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
-    flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
-
-  parallel_iterations = parallel_iterations or 32
-  if sequence_length is not None:
-    sequence_length = math_ops.to_int32(sequence_length)
-    if sequence_length.get_shape().ndims not in (None, 1):
-      raise ValueError(
-          "sequence_length must be a vector of length batch_size, "
-          "but saw shape: %s" % sequence_length.get_shape())
-    sequence_length = array_ops.identity(  # Just to find it in the graph.
-        sequence_length, name="sequence_length")
-
-  # Create a new scope in which the caching device is either
-  # determined by the parent scope, or is set to place the cached
-  # Variable using the same placement as for the rest of the RNN.
   with vs.variable_scope(scope or "rnn") as varscope:
+    # Create a new scope in which the caching device is either
+    # determined by the parent scope, or is set to place the cached
+    # Variable using the same placement as for the rest of the RNN.
     if context.in_graph_mode():
       if varscope.caching_device is None:
         varscope.set_caching_device(lambda op: op.device)

+    # By default, time_major==False and inputs are batch-major: shaped
+    #   [batch, time, depth]
+    # For internal calculations, we transpose to [time, batch, depth]
+    flat_input = nest.flatten(inputs)
+
+    if not time_major:
+      # (B,T,D) => (T,B,D)
+      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
+      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
+
+    parallel_iterations = parallel_iterations or 32
+    if sequence_length is not None:
+      sequence_length = math_ops.to_int32(sequence_length)
+      if sequence_length.get_shape().ndims not in (None, 1):
+        raise ValueError(
+            "sequence_length must be a vector of length batch_size, "
+            "but saw shape: %s" % sequence_length.get_shape())
+      sequence_length = array_ops.identity(  # Just to find it in the graph.
+          sequence_length, name="sequence_length")
+
     batch_size = _best_effort_input_batch_size(flat_input)

     if initial_state is not None:
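The rnn.py hunk only moves the input flattening and `sequence_length` checks inside the `rnn` variable scope; caller-visible behavior is unchanged. For reference, a typical call into this code path (standard TF 1.x API; shapes are illustrative):

```python
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(32)
# Batch-major inputs, [batch, time, depth]; transposed internally to
# [time, batch, depth] when time_major=False (the default), exactly as
# the moved comment describes.
inputs = tf.placeholder(tf.float32, [None, 10, 8])
seq_len = tf.placeholder(tf.int32, [None])   # one length per batch entry
outputs, state = tf.nn.dynamic_rnn(cell, inputs,
                                   sequence_length=seq_len,
                                   dtype=tf.float32)
```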
@ -108,11 +108,6 @@ static CUdeviceptr AsCudaDevicePtr(DeviceMemoryBase *gpu_mem) {
   return AsCudaDevicePtr(*gpu_mem);
 }

-static CudaContext* GetCudaContext(Stream *stream) {
-  return static_cast<CUDAExecutor *>(stream->parent()->implementation())
-      ->cuda_context();
-}
-
 CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec) {
   CHECK(cuda_exec != nullptr);
   return cuda_exec->cuda_context();

@ -380,11 +375,11 @@ bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,

   void **kernel_params = const_cast<void **>(args.argument_addresses().data());

-  if (!CUDADriver::LaunchKernel(GetCudaContext(stream), cufunc, block_dims.x,
-                                block_dims.y, block_dims.z, thread_dims.x,
-                                thread_dims.y, thread_dims.z,
-                                args.number_of_shared_bytes(), custream,
-                                kernel_params, nullptr /* = extra */)) {
+  if (!CUDADriver::LaunchKernel(context_, cufunc, block_dims.x, block_dims.y,
+                                block_dims.z, thread_dims.x, thread_dims.y,
+                                thread_dims.z, args.number_of_shared_bytes(),
+                                custream, kernel_params,
+                                nullptr /* = extra */)) {
     LOG(ERROR) << "failed to launch CUDA kernel with args: "
                << args.number_of_arguments()
                << "; thread dim: " << thread_dims.ToString()
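The launch-site change swaps the per-call `GetCudaContext(stream)` lookup for the executor's own `context_` member and reflows the argument list; the grid/block geometry itself is untouched. As a reminder of what those dims encode, the usual grid-sizing arithmetic (generic CUDA practice, not code from this file):

```python
# Ceil-divide the element count by the block size to get the grid size, so
# every element is covered by exactly one (block, thread) pair.
def launch_dims(num_elements, threads_per_block=256):
    blocks = (num_elements + threads_per_block - 1) // threads_per_block
    return blocks, threads_per_block

assert launch_dims(1000) == (4, 256)    # 4 * 256 = 1024 >= 1000
assert launch_dims(1024) == (4, 256)
```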