Fix TensorListFromTensor + XLA compile constant folding bug.
PiperOrigin-RevId: 249728378
commit fc3414b891 (parent 310a4a4419)
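
For context, a minimal sketch (not part of the commit) of the pattern this
change fixes, written against the same TF 1.x-era internal test APIs that
appear in the diff below. Under XLA, a TensorArray is lowered to TensorList
ops whose outputs are DT_VARIANT; because the cond predicate is a constant,
Grappler could previously constant-fold those ops, and the resulting variant
"constant" broke the XLA compile:

    from tensorflow.python.compiler.xla import xla
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import control_flow_ops
    from tensorflow.python.ops import tensor_array_ops

    def f():
      # The TensorArray lowers to TensorList* ops (DT_VARIANT handles) under XLA.
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
      output = control_flow_ops.cond(
          constant_op.constant(False),
          lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
      return output.stack()

    # Inside an XLAControlFlowContext (see the new tests below), this should
    # now evaluate to [10.] instead of failing during constant folding.
    output_t, = xla.compile(f)
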
diff --git a/tensorflow/compiler/tests/cond_test.py b/tensorflow/compiler/tests/cond_test.py
--- a/tensorflow/compiler/tests/cond_test.py
+++ b/tensorflow/compiler/tests/cond_test.py
@@ -19,11 +19,13 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.compiler.tests import xla_test
+from tensorflow.python.client import session
 from tensorflow.python.compiler.xla import xla
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -46,8 +48,8 @@ class CondTest(xla_test.XLATestCase):
       def f():
         ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
         output = control_flow_ops.cond(
-            constant_op.constant(
-                True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+            constant_op.constant(True),
+            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
 
         return output.stack()
 
@@ -56,6 +58,46 @@ class CondTest(xla_test.XLATestCase):
 
       xla_context.Exit()
 
+  def testCondAndTensorArrayInDefun_constFolding(self):
+    g = ops.Graph()
+    with session.Session(graph=g), g.as_default(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      @function.defun
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.cond(
+            constant_op.constant(False),
+            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+
+        return output.stack()
+
+      output_t = f()
+      self.assertAllEqual([10.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
+  def testCondAndTensorArray_xlaCompile(self):
+    self.skipTest("b/127846988")
+    # Fails with "Uninitialized arguments" in XlaIfOp::Compile
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.cond(
+            constant_op.constant(True),
+            lambda: ta.write(0, 5.), lambda: ta.write(0, 10.))
+
+        return output.stack()
+
+      output_t, = xla.compile(f)
+      self.assertAllEqual([5.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
   def testCondConstPropagation(self):
     with self.session() as sess, self.test_scope():
       xla_context = control_flow_ops.XLAControlFlowContext()
@@ -199,6 +241,28 @@ class CondTest(xla_test.XLATestCase):
 
       xla_context.Exit()
 
+  def testSwitchCaseAndTensorArray_xlaCompile(self):
+    self.skipTest("b/127846988")
+    with self.session(), self.test_scope():
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      xla_context.Enter()
+
+      def f():
+        ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
+        output = control_flow_ops.switch_case(
+            constant_op.constant(1), {
+                0: lambda: ta.write(0, 5.),
+                1: lambda: ta.write(0, 10.),
+                2: lambda: ta.write(0, 15.),
+            })
+
+        return output.stack()
+
+      output_t, = xla.compile(f)
+      self.assertAllEqual([10.], self.evaluate(output_t))
+
+      xla_context.Exit()
+
   def testSwitchCaseConstPropagation(self):
     self.skipTest("b/127846988")
     with self.session() as sess, self.test_scope():
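
The new switch_case test exercises the same variant-producing pattern through
an indexed branch. For reference, a sketch of the equivalent call through the
public API (tf.switch_case and tf.TensorArray are assumed available, as in
TF 1.14+/2.x; the tests above use the internal modules directly):

    import tensorflow.compat.v1 as tf

    ta = tf.TensorArray(dtype=tf.float32, size=1)
    output = tf.switch_case(
        tf.constant(1), {
            0: lambda: ta.write(0, 5.),
            1: lambda: ta.write(0, 10.),
            2: lambda: ta.write(0, 15.),
        })
    result = output.stack()  # branch 1 runs; expected [10.] when evaluated
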
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -114,70 +114,6 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
   }
 }
 
-void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
-                      Allocator* out_allocator, StringPiece edge_name,
-                      Device* src, Tensor* output,
-                      DeviceContext* send_dev_context, StatusCallback done) {
-  if (input->dtype() == DT_VARIANT) {
-    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
-    auto* status_cb = new ReffedStatusCallback(std::move(done));
-    core::ScopedUnref status_cb_unref(status_cb);
-
-    auto wrapped_done = [status_cb](const Status& s) {
-      status_cb->UpdateStatus(s);
-      status_cb->Unref();
-    };
-    auto copier = std::bind(
-        [edge_name, src, send_dev_context, out_allocator, status_cb,
-         cpu_allocator](StatusCallback wrapped_done_,
-                        // Begin unbound arguments
-                        const Tensor& from, Tensor* to) {
-          if (from.dtype() == DT_VARIANT) {
-            status_cb->Ref();
-            CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
-                             src, to, send_dev_context, wrapped_done_);
-            return Status::OK();
-          } else {
-            if (!DMAHelper::CanUseDMA(&from)) {
-              Status err = errors::InvalidArgument(
-                  "During Variant Device->Host Copy: "
-                  "non-DMA-copy attempted of tensor type: ",
-                  DataTypeString(from.dtype()));
-              status_cb->UpdateStatus(err);
-              return err;
-            }
-            if (status_cb->ok()) {
-              status_cb->Ref();
-              *to = Tensor(out_allocator, from.dtype(), from.shape());
-              send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
-                                                      wrapped_done_);
-              return Status::OK();
-            } else {
-              return status_cb->status();
-            }
-          }
-        },
-        std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
-
-    const Variant* v = input->flat<Variant>().data();
-    Variant* v_out = copy.flat<Variant>().data();
-    Status s_copy_init;
-    for (int64 i = 0; i < input->NumElements(); ++i) {
-      s_copy_init = VariantDeviceCopy(
-          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
-      if (!s_copy_init.ok()) {
-        status_cb->UpdateStatus(s_copy_init);
-        break;
-      }
-    }
-    if (s_copy_init.ok()) {
-      *output = std::move(copy);
-    }
-  } else {
-    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
-                                            std::move(done));
-  }
-}
 
 void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
                         Allocator* cpu_allocator, Allocator* out_allocator,
@@ -390,4 +326,69 @@ REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
 
 }  // namespace
 
+void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
+                      Allocator* out_allocator, StringPiece edge_name,
+                      Device* src, Tensor* output,
+                      DeviceContext* send_dev_context, StatusCallback done) {
+  if (input->dtype() == DT_VARIANT) {
+    Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
+    auto* status_cb = new ReffedStatusCallback(std::move(done));
+    core::ScopedUnref status_cb_unref(status_cb);
+
+    auto wrapped_done = [status_cb](const Status& s) {
+      status_cb->UpdateStatus(s);
+      status_cb->Unref();
+    };
+    auto copier = std::bind(
+        [edge_name, src, send_dev_context, out_allocator, status_cb,
+         cpu_allocator](StatusCallback wrapped_done_,
+                        // Begin unbound arguments
+                        const Tensor& from, Tensor* to) {
+          if (from.dtype() == DT_VARIANT) {
+            status_cb->Ref();
+            CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+                             src, to, send_dev_context, wrapped_done_);
+            return Status::OK();
+          } else {
+            if (!DMAHelper::CanUseDMA(&from)) {
+              Status err = errors::InvalidArgument(
+                  "During Variant Device->Host Copy: "
+                  "non-DMA-copy attempted of tensor type: ",
+                  DataTypeString(from.dtype()));
+              status_cb->UpdateStatus(err);
+              return err;
+            }
+            if (status_cb->ok()) {
+              status_cb->Ref();
+              *to = Tensor(out_allocator, from.dtype(), from.shape());
+              send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+                                                      wrapped_done_);
+              return Status::OK();
+            } else {
+              return status_cb->status();
+            }
+          }
+        },
+        std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
+
+    const Variant* v = input->flat<Variant>().data();
+    Variant* v_out = copy.flat<Variant>().data();
+    Status s_copy_init;
+    for (int64 i = 0; i < input->NumElements(); ++i) {
+      s_copy_init = VariantDeviceCopy(
+          VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier);
+      if (!s_copy_init.ok()) {
+        status_cb->UpdateStatus(s_copy_init);
+        break;
+      }
+    }
+    if (s_copy_init.ok()) {
+      *output = std::move(copy);
+    }
+  } else {
+    send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, output,
+                                            std::move(done));
+  }
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/copy_tensor.h b/tensorflow/core/common_runtime/copy_tensor.h
--- a/tensorflow/core/common_runtime/copy_tensor.h
+++ b/tensorflow/core/common_runtime/copy_tensor.h
@@ -69,6 +69,11 @@ class CopyTensor {
                          CopyFunction copy_function);
 };
 
+void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
+                      Allocator* out_allocator, StringPiece edge_name,
+                      Device* src, Tensor* output,
+                      DeviceContext* send_dev_context, StatusCallback done);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -22,9 +22,10 @@ limitations under the License.
 
 #include "grpcpp/alarm.h"
 #include "grpcpp/server_builder.h"
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
+#include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
@@ -545,8 +545,8 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
           delete copy;
         };
 
-        send_dev_context->CopyDeviceTensorToCPU(
-            &val, request->rendezvous_key(), src_dev, copy, copy_ready);
+        CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
+                         src_dev, copy, send_dev_context, copy_ready);
         return;
       }
     }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/tensor_util.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -896,6 +897,13 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
   if (op_def->output_arg_size() == 0) {
     return false;
   }
+  // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
+  // TODO(rmlarsen): Only do this for XLA_* devices.
+  for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
+    if (output_arg.type() == DT_VARIANT) {
+      return false;
+    }
+  }
 
   // No need to (and don't) fold nodes that have no outgoing edges except
   // whitelisted nodes. Such nodes could be introduced by an earlier constant
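
The IsFoldable change above is the core of the fix: any node with a DT_VARIANT
output is now skipped by constant folding. A quick illustration (hedged: uses
tf.raw_ops, available in TF 1.14+) of why such outputs cannot be materialized
as constants: TensorList ops hand back opaque variant handles with no
TensorProto representation for Grappler to fold, and XLA cannot consume a
folded variant constant.

    import tensorflow.compat.v1 as tf

    t = tf.constant([[1., 2.], [3., 4.]])
    # Each row becomes one list element; the handle itself is DT_VARIANT.
    handle = tf.raw_ops.TensorListFromTensor(
        tensor=t, element_shape=tf.constant([2]))
    print(handle.dtype)  # 'variant'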