Fix TensorListFromTensor + XLA compile constant folding bug.

PiperOrigin-RevId: 249728378
Authored by Brian Patton on 2019-05-23 15:31:49 -07:00; committed by TensorFlower Gardener
parent 310a4a4419
commit fc3414b891
6 changed files with 148 additions and 70 deletions
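
Background (not part of the original commit message): under XLA control flow, TensorArray is backed by TensorList ops such as TensorListFromTensor, whose outputs have dtype DT_VARIANT. Before this change, Grappler's constant folding could treat such a variant-producing node as foldable while the graph was being compiled with xla.compile, which broke compilation; the constant_folding.cc change below simply refuses to fold any op with a DT_VARIANT output. A minimal sketch of the pattern named in the commit title, adapted from the tests added below; it assumes a TF 1.x graph-mode build with XLA available and is illustrative only:

# Sketch only: mirrors the failure mode named in the commit title. The exact
# op names and expected values here are illustrative, not from the commit.
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import list_ops


def f():
  t = constant_op.constant([5., 10.])
  # TensorListFromTensor has a DT_VARIANT output; with constant inputs it was
  # a constant-folding candidate before the constant_folding.cc change below.
  tl = list_ops.tensor_list_from_tensor(
      t, element_shape=constant_op.constant([], dtype=dtypes.int32))
  return list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)


with ops.Graph().as_default(), session.Session() as sess:
  (output_t,) = xla.compile(f)
  print(sess.run(output_t))  # expect [5. 10.]; could fail to compile before the fix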

tensorflow/compiler/tests/cond_test.py

@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import session
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase):
      def f():
        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
        output = control_flow_ops.cond(
            constant_op.constant(
                True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
            constant_op.constant(True),
            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
        return output.stack()
@@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase):
      xla_context.Exit()

  def testCondAndTensorArrayInDefun_constFolding(self):
    g = ops.Graph()
    with session.Session(graph=g), g.as_default(), self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()

      @function.defun
      def f():
        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
        output = control_flow_ops.cond(
            constant_op.constant(False),
            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
        return output.stack()

      output_t = f()
      self.assertAllEqual([10.], self.evaluate(output_t))
      xla_context.Exit()

  def testCondAndTensorArray_xlaCompile(self):
    self.skipTest("b/127846988")
    # Fails with "Uninitialized arguments" in XlaIfOp::Compile
    with self.session(), self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()

      def f():
        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
        output = control_flow_ops.cond(
            constant_op.constant(True),
            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
        return output.stack()

      output_t, = xla.compile(f)
      self.assertAllEqual([5.], self.evaluate(output_t))
      xla_context.Exit()

  def testCondConstPropagation(self):
    with self.session() as sess, self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
@@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase):
      xla_context.Exit()

  def testSwitchCaseAndTensorArray_xlaCompile(self):
    self.skipTest("b/127846988")
    with self.session(), self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()

      def f():
        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
        output = control_flow_ops.switch_case(
            constant_op.constant(1), {
                0: lambda: ta.write(0, 5.),
                1: lambda: ta.write(0, 10.),
                2: lambda: ta.write(0, 15.),
            })
        return output.stack()

      output_t, = xla.compile(f)
      self.assertAllEqual([10.], self.evaluate(output_t))
      xla_context.Exit()

  def testSwitchCaseConstPropagation(self):
    self.skipTest("b/127846988")
    with self.session() as sess, self.test_scope():

tensorflow/core/common_runtime/copy_tensor.cc

@@ -114,70 +114,6 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
  }
}

void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
                      Allocator* out_allocator, StringPiece edge_name,
                      Device* src, Tensor* output,
                      DeviceContext* send_dev_context, StatusCallback done) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);
    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier = std::bind(
        [edge_name, src, send_dev_context, out_allocator, status_cb,
         cpu_allocator](StatusCallback wrapped_done_,
                        // Begin unbound arguments
                        const Tensor& from, Tensor* to) {
          if (from.dtype() == DT_VARIANT) {
            status_cb->Ref();
            CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
                             src, to, send_dev_context, wrapped_done_);
            return Status::OK();
          } else {
            if (!DMAHelper::CanUseDMA(&from)) {
              Status err = errors::InvalidArgument(
                  "During Variant Device->Host Copy: "
                  "non-DMA-copy attempted of tensor type: ",
                  DataTypeString(from.dtype()));
              status_cb->UpdateStatus(err);
              return err;
            }
            if (status_cb->ok()) {
              status_cb->Ref();
              *to = Tensor(out_allocator, from.dtype(), from.shape());
              send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
                                                      wrapped_done_);
              return Status::OK();
            } else {
              return status_cb->status();
            }
          }
        },
        std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64 i = 0; i < input->NumElements(); ++i) {
      s_copy_init = VariantDeviceCopy(
          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else {
    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
                                            std::move(done));
  }
}

void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
                        Allocator* cpu_allocator, Allocator* out_allocator,
@@ -390,4 +326,69 @@ REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

}  // namespace

void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
                      Allocator* out_allocator, StringPiece edge_name,
                      Device* src, Tensor* output,
                      DeviceContext* send_dev_context, StatusCallback done) {
  if (input->dtype() == DT_VARIANT) {
    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
    auto* status_cb = new ReffedStatusCallback(std::move(done));
    core::ScopedUnref status_cb_unref(status_cb);
    auto wrapped_done = [status_cb](const Status& s) {
      status_cb->UpdateStatus(s);
      status_cb->Unref();
    };
    auto copier = std::bind(
        [edge_name, src, send_dev_context, out_allocator, status_cb,
         cpu_allocator](StatusCallback wrapped_done_,
                        // Begin unbound arguments
                        const Tensor& from, Tensor* to) {
          if (from.dtype() == DT_VARIANT) {
            status_cb->Ref();
            CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
                             src, to, send_dev_context, wrapped_done_);
            return Status::OK();
          } else {
            if (!DMAHelper::CanUseDMA(&from)) {
              Status err = errors::InvalidArgument(
                  "During Variant Device->Host Copy: "
                  "non-DMA-copy attempted of tensor type: ",
                  DataTypeString(from.dtype()));
              status_cb->UpdateStatus(err);
              return err;
            }
            if (status_cb->ok()) {
              status_cb->Ref();
              *to = Tensor(out_allocator, from.dtype(), from.shape());
              send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
                                                      wrapped_done_);
              return Status::OK();
            } else {
              return status_cb->status();
            }
          }
        },
        std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
    const Variant* v = input->flat<Variant>().data();
    Variant* v_out = copy.flat<Variant>().data();
    Status s_copy_init;
    for (int64 i = 0; i < input->NumElements(); ++i) {
      s_copy_init = VariantDeviceCopy(
          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
      if (!s_copy_init.ok()) {
        status_cb->UpdateStatus(s_copy_init);
        break;
      }
    }
    if (s_copy_init.ok()) {
      *output = std::move(copy);
    }
  } else {
    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
                                            std::move(done));
  }
}

}  // namespace tensorflow

tensorflow/core/common_runtime/copy_tensor.h

@@ -69,6 +69,11 @@ class CopyTensor {
                         CopyFunction copy_function);
};

void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
                      Allocator* out_allocator, StringPiece edge_name,
                      Device* src, Tensor* output,
                      DeviceContext* send_dev_context, StatusCallback done);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_

tensorflow/core/common_runtime/gpu/gpu_util.cc

@@ -100,7 +100,7 @@ Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src,
  }
  if (!DMAHelper::CanUseDMA(&src)) {
    return errors::Internal("GPU copy from non-DMA ",
                            DataTypeString(src.dtype()), "tensor");
                            DataTypeString(src.dtype()), " tensor");
  }
  return Status::OK();
}

tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc

@@ -22,9 +22,9 @@ limitations under the License.
#include "grpcpp/alarm.h"
#include "grpcpp/server_builder.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -545,8 +545,8 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
        delete copy;
      };
      send_dev_context->CopyDeviceTensorToCPU(
          &val, request->rendezvous_key(), src_dev, copy, copy_ready);
      CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
                       src_dev, copy, send_dev_context, copy_ready);
      return;
    }
  }

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -896,6 +897,13 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
  if (op_def->output_arg_size() == 0) {
    return false;
  }

  // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
  // TODO(rmlarsen): Only do this for XLA_* devices.
  for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
    if (output_arg.type() == DT_VARIANT) {
      return false;
    }
  }

  // No need to (and don't) fold nodes that have no outgoing edges except
  // whitelisted nodes. Such nodes could be introduced by an earlier constant
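
The new check keys off the OpDef rather than runtime values: any op that declares a DT_VARIANT output is now never folded, regardless of whether its inputs are constant. A hedged sketch (not from the commit) of how to confirm that an op such as TensorListFromTensor falls under the check, assuming the TF 1.x op_def_registry API:

# Sketch only: confirms TensorListFromTensor declares a DT_VARIANT output and
# is therefore skipped by ConstantFolding::IsFoldable after this change.
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.ops import list_ops  # noqa: F401 -- registers the TensorList op defs

op_def = op_def_registry.get_registered_ops()["TensorListFromTensor"]
print(any(arg.type == dtypes.variant.as_datatype_enum
          for arg in op_def.output_arg))  # True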