Fix TensorListFromTensor + XLA compile constant folding bug.
PiperOrigin-RevId: 249728378
parent 310a4a4419
commit fc3414b891
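For context, a hedged sketch of the failure mode this change addresses: inside an XLA control-flow context a TensorArray is lowered to TensorList ops (such as the TensorListFromTensor named in the title), whose outputs are DT_VARIANT, and Grappler's constant folder could previously try to fold those nodes and break the XLA compile path. The snippet below is not part of the commit; it mirrors the new testCondAndTensorArrayInDefun_constFolding test added further down, omits the XLA-device test_scope() that the real test also enters, and assumes TF 1.x-era internal APIs.

# Hypothetical repro sketch (assumes TF 1.x internals; see the new test in the diff).
from tensorflow.python.client import session
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops

g = ops.Graph()
with session.Session(graph=g) as sess, g.as_default():
  xla_context = control_flow_ops.XLAControlFlowContext()
  xla_context.Enter()

  @function.defun
  def f():
    # Under XLA control flow this TensorArray is backed by TensorList ops,
    # whose DT_VARIANT outputs are exactly what the Grappler change below
    # now refuses to constant-fold.
    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
    output = control_flow_ops.cond(
        constant_op.constant(False),
        lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
    return output.stack()

  output_t = f()
  print(sess.run(output_t))  # expected: [10.]

  xla_context.Exit()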
@@ -19,11 +19,13 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.client import session
 from tensorflow.python.compiler.xla import xla
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase):
       def f():
         ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
         output = control_flow_ops.cond(
-            constant_op.constant(
-                True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+            constant_op.constant(True),
+            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
 
         return output.stack()
 
@@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase):
 
       xla_context.Exit()
 
+  def testCondAndTensorArrayInDefun_constFolding(self):
+    g = ops.Graph()
+    with session.Session(graph=g), g.as_default(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      @function.defun
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.cond(
+            constant_op.constant(False),
+            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+
+        return output.stack()
+
+      output_t = f()
+      self.assertAllEqual([10.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
+  def testCondAndTensorArray_xlaCompile(self):
+    self.skipTest("b/127846988")
+    # Fails with "Uninitialized arguments" in XlaIfOp::Compile
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.cond(
+            constant_op.constant(True),
+            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+
+        return output.stack()
+
+      output_t, = xla.compile(f)
+      self.assertAllEqual([5.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
   def testCondConstPropagation(self):
     with self.session() as sess, self.test_scope():
       xla_context = control_flow_ops.XLAControlFlowContext()
@@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase):
 
       xla_context.Exit()
 
+  def testSwitchCaseAndTensorArray_xlaCompile(self):
+    self.skipTest("b/127846988")
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.switch_case(
+            constant_op.constant(1), {
+                0: lambda: ta.write(0, 5.),
+                1: lambda: ta.write(0, 10.),
+                2: lambda: ta.write(0, 15.),
+            })
+
+        return output.stack()
+
+      output_t, = xla.compile(f)
+      self.assertAllEqual([10.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
   def testSwitchCaseConstPropagation(self):
     self.skipTest("b/127846988")
     with self.session() as sess, self.test_scope():
@@ -114,70 +114,6 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
   }
 }
 
-void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
-                      Allocator* out_allocator, StringPiece edge_name,
-                      Device* src, Tensor* output,
-                      DeviceContext* send_dev_context, StatusCallback done) {
-  if (input->dtype() == DT_VARIANT) {
-    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
-    auto* status_cb = new ReffedStatusCallback(std::move(done));
-    core::ScopedUnref status_cb_unref(status_cb);
-
-    auto wrapped_done = [status_cb](const Status& s) {
-      status_cb->UpdateStatus(s);
-      status_cb->Unref();
-    };
-    auto copier = std::bind(
-        [edge_name, src, send_dev_context, out_allocator, status_cb,
-         cpu_allocator](StatusCallback wrapped_done_,
-                        // Begin unbound arguments
-                        const Tensor& from, Tensor* to) {
-          if (from.dtype() == DT_VARIANT) {
-            status_cb->Ref();
-            CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
-                             src, to, send_dev_context, wrapped_done_);
-            return Status::OK();
-          } else {
-            if (!DMAHelper::CanUseDMA(&from)) {
-              Status err = errors::InvalidArgument(
-                  "During Variant Device->Host Copy: "
-                  "non-DMA-copy attempted of tensor type: ",
-                  DataTypeString(from.dtype()));
-              status_cb->UpdateStatus(err);
-              return err;
-            }
-            if (status_cb->ok()) {
-              status_cb->Ref();
-              *to = Tensor(out_allocator, from.dtype(), from.shape());
-              send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
-                                                      wrapped_done_);
-              return Status::OK();
-            } else {
-              return status_cb->status();
-            }
-          }
-        },
-        std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
-
-    const Variant* v = input->flat<Variant>().data();
-    Variant* v_out = copy.flat<Variant>().data();
-    Status s_copy_init;
-    for (int64 i = 0; i < input->NumElements(); ++i) {
-      s_copy_init = VariantDeviceCopy(
-          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
-      if (!s_copy_init.ok()) {
-        status_cb->UpdateStatus(s_copy_init);
-        break;
-      }
-    }
-    if (s_copy_init.ok()) {
-      *output = std::move(copy);
-    }
-  } else {
-    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
-                                            std::move(done));
-  }
-}
 
 void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
                         Allocator* cpu_allocator, Allocator* out_allocator,
@@ -390,4 +326,69 @@ REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
 
 }  // namespace
 
+void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
+                      Allocator* out_allocator, StringPiece edge_name,
+                      Device* src, Tensor* output,
+                      DeviceContext* send_dev_context, StatusCallback done) {
+  if (input->dtype() == DT_VARIANT) {
+    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
+    auto* status_cb = new ReffedStatusCallback(std::move(done));
+    core::ScopedUnref status_cb_unref(status_cb);
+
+    auto wrapped_done = [status_cb](const Status& s) {
+      status_cb->UpdateStatus(s);
+      status_cb->Unref();
+    };
+    auto copier = std::bind(
+        [edge_name, src, send_dev_context, out_allocator, status_cb,
+         cpu_allocator](StatusCallback wrapped_done_,
+                        // Begin unbound arguments
+                        const Tensor& from, Tensor* to) {
+          if (from.dtype() == DT_VARIANT) {
+            status_cb->Ref();
+            CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+                             src, to, send_dev_context, wrapped_done_);
+            return Status::OK();
+          } else {
+            if (!DMAHelper::CanUseDMA(&from)) {
+              Status err = errors::InvalidArgument(
+                  "During Variant Device->Host Copy: "
+                  "non-DMA-copy attempted of tensor type: ",
+                  DataTypeString(from.dtype()));
+              status_cb->UpdateStatus(err);
+              return err;
+            }
+            if (status_cb->ok()) {
+              status_cb->Ref();
+              *to = Tensor(out_allocator, from.dtype(), from.shape());
+              send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+                                                      wrapped_done_);
+              return Status::OK();
+            } else {
+              return status_cb->status();
+            }
+          }
+        },
+        std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
+
+    const Variant* v = input->flat<Variant>().data();
+    Variant* v_out = copy.flat<Variant>().data();
+    Status s_copy_init;
+    for (int64 i = 0; i < input->NumElements(); ++i) {
+      s_copy_init = VariantDeviceCopy(
+          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
+      if (!s_copy_init.ok()) {
+        status_cb->UpdateStatus(s_copy_init);
+        break;
+      }
+    }
+    if (s_copy_init.ok()) {
+      *output = std::move(copy);
+    }
+  } else {
+    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
+                                            std::move(done));
+  }
+}
+
 }  // namespace tensorflow
@@ -69,6 +69,11 @@ class CopyTensor {
                          CopyFunction copy_function);
 };
 
+void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
+                      Allocator* out_allocator, StringPiece edge_name,
+                      Device* src, Tensor* output,
+                      DeviceContext* send_dev_context, StatusCallback done);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_
@@ -100,7 +100,7 @@ Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src,
   }
   if (!DMAHelper::CanUseDMA(&src)) {
     return errors::Internal("GPU copy from non-DMA ",
-                            DataTypeString(src.dtype()), "tensor");
+                            DataTypeString(src.dtype()), " tensor");
   }
   return Status::OK();
 }
@@ -22,9 +22,9 @@ limitations under the License.
 
 #include "grpcpp/alarm.h"
 #include "grpcpp/server_builder.h"
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
@@ -545,8 +545,8 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
             delete copy;
           };
 
-          send_dev_context->CopyDeviceTensorToCPU(
-              &val, request->rendezvous_key(), src_dev, copy, copy_ready);
+          CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
+                           src_dev, copy, send_dev_context, copy_ready);
           return;
         }
       }
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -896,6 +897,13 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
   if (op_def->output_arg_size() == 0) {
     return false;
   }
+  // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
+  // TODO(rmlarsen): Only do this for XLA_* devices.
+  for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
+    if (output_arg.type() == DT_VARIANT) {
+      return false;
+    }
+  }
 
   // No need to (and don't) fold nodes that have no outgoing edges except
   // whitelisted nodes. Such nodes could be introduced by an earlier constant
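As a reference for the new IsFoldable() guard above, here is a hedged Python sketch (not part of the commit; TF 1.x internal APIs assumed) that mirrors its check: an op is now excluded from constant folding when its registered OpDef declares any DT_VARIANT output, which covers the TensorListFromTensor op from the commit title.

# Hypothetical helper mirroring the C++ guard in ConstantFolding::IsFoldable().
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import op_def_registry


def has_variant_output(op_name):
  op_def = op_def_registry.get_registered_ops()[op_name]
  return any(arg.type == types_pb2.DT_VARIANT for arg in op_def.output_arg)


print(has_variant_output("TensorListFromTensor"))  # True: no longer folded.
print(has_variant_output("Add"))                   # False: still foldable.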