commit
b15caeedf6
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -54,11 +55,18 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
<< tensorflow::str_util::Join(disabled_passes, ", ");
|
||||
}
|
||||
|
||||
auto run_invariant_checkers = [this, module]() -> Status {
|
||||
auto run_invariant_checkers = [this,
|
||||
module](const string& message) -> Status {
|
||||
for (auto& invariant_checker : invariant_checkers_) {
|
||||
VLOG(1) << " Invariant checker " << invariant_checker->name();
|
||||
TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module));
|
||||
TF_RET_CHECK(!changed) << "invariant checkers must not change the graph";
|
||||
StatusOr<bool> changed_status = invariant_checker->Run(module);
|
||||
if (!changed_status.ok()) {
|
||||
return Status(changed_status.status().code(),
|
||||
StrCat(changed_status.status().error_message(),
|
||||
"\n\nFailed ", message));
|
||||
}
|
||||
TF_RET_CHECK(!changed_status.ValueOrDie())
|
||||
<< "invariant checkers must not change the graph";
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
@ -66,6 +74,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
string prefix = name().ToString() + ": pipeline start";
|
||||
bool changed = false;
|
||||
string message;
|
||||
TF_RETURN_IF_ERROR(
|
||||
run_invariant_checkers(StrCat("before running pipeline: ", name())));
|
||||
for (auto& pass : passes_) {
|
||||
if (disabled_passes.count(pass->name().ToString()) > 0) {
|
||||
VLOG(1) << " Skipping HLO pass " << pass->name()
|
||||
@ -80,14 +90,14 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
StrAppend(&message, prefix, ", before ", pass->name());
|
||||
DumpModule(*module, message);
|
||||
|
||||
TF_RETURN_IF_ERROR(run_invariant_checkers());
|
||||
TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
|
||||
TF_RETURN_IF_ERROR(
|
||||
run_invariant_checkers(StrCat("after running pass: ", pass->name())));
|
||||
|
||||
changed |= changed_this_pass;
|
||||
prefix.clear();
|
||||
StrAppend(&prefix, name(), ": after ", pass->name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(run_invariant_checkers());
|
||||
DumpModule(*module, prefix + ", pipeline end");
|
||||
return changed;
|
||||
}
|
||||
|
@ -1202,7 +1202,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
|
||||
StatusOr<bool> HloRematerialization::Run(
|
||||
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
|
||||
int64 memory_limit_bytes) {
|
||||
int64 memory_limit_bytes, RematerializationSizes* sizes) {
|
||||
// The sequence is constructed entirely by this method.
|
||||
TF_RET_CHECK(sequence->empty());
|
||||
|
||||
@ -1319,13 +1319,20 @@ StatusOr<bool> HloRematerialization::Run(
|
||||
<< HumanReadableNumBytes(reduced_peak_memory) << " ("
|
||||
<< reduced_peak_memory << " bytes)";
|
||||
|
||||
if (sizes != nullptr) {
|
||||
sizes->before_bytes = before_peak_memory;
|
||||
sizes->after_bytes = current_peak_memory;
|
||||
}
|
||||
|
||||
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
|
||||
|
||||
if (current_peak_memory > memory_limit_bytes) {
|
||||
LOG(WARNING) << "Can't reduce memory use below "
|
||||
<< HumanReadableNumBytes(memory_limit_bytes)
|
||||
<< " by rematerialization (only reduced to "
|
||||
<< HumanReadableNumBytes(current_peak_memory) << ")";
|
||||
LOG(WARNING) << tensorflow::strings::Printf(
|
||||
"Can't reduce memory use below %s (%lld bytes) by rematerialization; "
|
||||
"only reduced to %s (%lld bytes)",
|
||||
HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes,
|
||||
HumanReadableNumBytes(current_peak_memory).c_str(),
|
||||
current_peak_memory);
|
||||
}
|
||||
|
||||
return changed;
|
||||
@ -1334,9 +1341,10 @@ StatusOr<bool> HloRematerialization::Run(
|
||||
/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
|
||||
const HloRematerialization::ShapeSizeFunction& size_function,
|
||||
int64 memory_limit_bytes, HloModule* hlo_module,
|
||||
SequentialHloOrdering::HloModuleSequence* sequence) {
|
||||
SequentialHloOrdering::HloModuleSequence* sequence,
|
||||
RematerializationSizes* sizes) {
|
||||
HloRematerialization remat(size_function);
|
||||
return remat.Run(hlo_module, sequence, memory_limit_bytes);
|
||||
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -28,6 +28,13 @@ class HloRematerialization {
|
||||
public:
|
||||
using ShapeSizeFunction = std::function<int64(const Shape&)>;
|
||||
|
||||
// Helper struct that communicates the before / after sizes for the
|
||||
// rematerialization process.
|
||||
struct RematerializationSizes {
|
||||
int64 before_bytes;
|
||||
int64 after_bytes;
|
||||
};
|
||||
|
||||
// Rematerialize HLO instructions in the given module to reduce peak memory
|
||||
// use below memory_limit_bytes where memory use is defined as the total size
|
||||
// of all live HLO instruction values. Parameters and constants are included
|
||||
@ -46,6 +53,9 @@ class HloRematerialization {
|
||||
// rematerialization. This is the order in which HLO instructions should
|
||||
// be emitted to minimize memory use.
|
||||
//
|
||||
// sizes: Optional outparam that indicates the peak memory usage of the HLO
|
||||
// module before/after rematerialization.
|
||||
//
|
||||
// Returns whether any instructions were rematerialized. If memory use is
|
||||
// already below the given limit then no instructions are rematerialized and
|
||||
// false is returned.
|
||||
@ -55,8 +65,8 @@ class HloRematerialization {
|
||||
// code generation.
|
||||
static StatusOr<bool> RematerializeAndSchedule(
|
||||
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
|
||||
HloModule* hlo_module,
|
||||
SequentialHloOrdering::HloModuleSequence* sequence);
|
||||
HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence,
|
||||
RematerializationSizes* sizes = nullptr);
|
||||
|
||||
protected:
|
||||
HloRematerialization(const ShapeSizeFunction& size_function)
|
||||
@ -69,7 +79,7 @@ class HloRematerialization {
|
||||
// contains the memory-minimizing order in which to emit the HLO instructions.
|
||||
StatusOr<bool> Run(HloModule* module,
|
||||
SequentialHloOrdering::HloModuleSequence* sequence,
|
||||
int64 memory_limit);
|
||||
int64 memory_limit, RematerializationSizes* sizes);
|
||||
|
||||
// Rematerializes instructions within the given computation. 'order' is the
|
||||
// order in which the computation's instructions will be emitted in the
|
||||
|
@ -123,7 +123,8 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
|
||||
}
|
||||
|
||||
if (instruction->opcode() == HloOpcode::kFusion &&
|
||||
instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
(instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
|
||||
instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) {
|
||||
// Insert the reduce-precision operation inside the fusion computation,
|
||||
// after the corresponding parameter instruction.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
@ -173,7 +174,8 @@ StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
|
||||
}
|
||||
|
||||
if (instruction->opcode() == HloOpcode::kFusion &&
|
||||
instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
(instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
|
||||
instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) {
|
||||
// Insert the reduce-precision operation as the last operation inside
|
||||
// the fusion computation.
|
||||
HloInstruction* fusion_root = instruction->fused_expression_root();
|
||||
|
@ -240,6 +240,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks.
|
||||
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
|
||||
# Broken tensorboard test due to cmake issues.
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py"
|
||||
|
@ -166,10 +166,11 @@ struct Expector<T, false> {
|
||||
static void Equal(const Tensor& x, const Tensor& y) {
|
||||
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
|
||||
AssertSameTypeDims(x, y);
|
||||
auto a = x.flat<T>();
|
||||
auto b = y.flat<T>();
|
||||
for (int i = 0; i < a.size(); ++i) {
|
||||
ExpectEqual(a(i), b(i));
|
||||
const auto size = x.NumElements();
|
||||
const T* a = x.flat<T>().data();
|
||||
const T* b = y.flat<T>().data();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ExpectEqual(a[i], b[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -182,10 +183,11 @@ struct Expector<T, true> {
|
||||
static void Equal(const Tensor& x, const Tensor& y) {
|
||||
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
|
||||
AssertSameTypeDims(x, y);
|
||||
auto a = x.flat<T>();
|
||||
auto b = y.flat<T>();
|
||||
for (int i = 0; i < a.size(); ++i) {
|
||||
ExpectEqual(a(i), b(i));
|
||||
const auto size = x.NumElements();
|
||||
const T* a = x.flat<T>().data();
|
||||
const T* b = y.flat<T>().data();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ExpectEqual(a[i], b[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,10 +201,11 @@ struct Expector<T, true> {
|
||||
static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
|
||||
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
|
||||
AssertSameTypeDims(x, y);
|
||||
auto a = x.flat<T>();
|
||||
auto b = y.flat<T>();
|
||||
for (int i = 0; i < a.size(); ++i) {
|
||||
Near(a(i), b(i), abs_err, i);
|
||||
const auto size = x.NumElements();
|
||||
const T* a = x.flat<T>().data();
|
||||
const T* b = y.flat<T>().data();
|
||||
for (int i = 0; i < size; ++i) {
|
||||
Near(a[i], b[i], abs_err, i);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -529,7 +529,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
||||
const int64 out_depths = GetTensorDim(*output, data_format, 'C');
|
||||
const int64 patch_rows = filter.dim_size(0);
|
||||
const int64 patch_cols = filter.dim_size(1);
|
||||
if (padding == Eigen::PADDING_SAME) {
|
||||
if (padding == SAME) {
|
||||
// Total padding on rows and cols is
|
||||
// Pr = (R' - 1) * S + Kr - R
|
||||
// Pc = (C' - 1) * S + Kc - C
|
||||
|
@ -91,7 +91,14 @@ class CTCLossOp : public OpKernel {
|
||||
OP_REQUIRES(ctx, batch_size != 0,
|
||||
errors::InvalidArgument("batch_size must not be 0"));
|
||||
|
||||
TensorShape labels_shape({batch_size, max_time});
|
||||
// Figure out the maximum label length to use as sparse tensor dimension.
|
||||
auto labels_indices_t = labels_indices->matrix<int64>();
|
||||
int64 max_label_len = 0;
|
||||
for (int i = 0; i < labels_indices->dim_size(0); i++) {
|
||||
max_label_len = std::max(max_label_len, labels_indices_t(i, 1) + 1);
|
||||
}
|
||||
|
||||
TensorShape labels_shape({batch_size, max_label_len});
|
||||
std::vector<int64> order{0, 1};
|
||||
sparse::SparseTensor labels_sp(*labels_indices, *labels_values,
|
||||
labels_shape, order);
|
||||
|
@ -392,6 +392,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
self._cached_session = None
|
||||
|
||||
def setUp(self):
|
||||
logging.info("SET UP: %s" % str(self))
|
||||
self._ClearCachedSession()
|
||||
random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
|
||||
@ -406,6 +407,7 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED
|
||||
|
||||
def tearDown(self):
|
||||
logging.info("TEAR DOWN: %s" % str(self))
|
||||
for thread in self._threads:
|
||||
self.assertFalse(thread.is_alive(), "A checkedThread did not terminate")
|
||||
|
||||
|
@ -505,6 +505,8 @@ class ResourceVariable(variables.Variable):
|
||||
def sparse_read(self, indices, name=None):
|
||||
"""Reads the value of this variable sparsely, using `gather`."""
|
||||
with ops.name_scope("Gather" if name is None else name) as name:
|
||||
if self._trainable:
|
||||
tape.watch(self._handle)
|
||||
value = resource_gather(
|
||||
self._handle, indices, dtype=self._dtype, name=name)
|
||||
return array_ops.identity(value)
|
||||
|
@ -699,9 +699,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
||||
native.http_archive(
|
||||
name = "bazel_toolchains",
|
||||
urls = [
|
||||
"http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz",
|
||||
"https://github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz",
|
||||
"http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/9dbd803ad3b9447430a296810197b09b3a710956.tar.gz",
|
||||
"https://github.com/bazelbuild/bazel-toolchains/archive/9dbd803ad3b9447430a296810197b09b3a710956.tar.gz",
|
||||
],
|
||||
sha256 = "3903fd93b96b42067e00b7973a2c16c34e761ad7a0b55e1557d408f352849e41",
|
||||
strip_prefix = "bazel-toolchains-bccee4855c049d34bac481083b4c68e2fab8cc50",
|
||||
sha256 = "0799aa12db5260a499beb40f81744e760c59d055bfc5d271dd2c2ed4d5419faa",
|
||||
strip_prefix = "bazel-toolchains-9dbd803ad3b9447430a296810197b09b3a710956",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user