[TF2XLA] Remove XlaTensor::set_host_tensor. It creates unnecessary complication
in the tf2xla bridge. It creates unnecessary complication; if caching is truly needed, it can be maintained in a side data structure. Avoiding extra copies does not justify the complexity of the implementation: if extra copies are a concern, op-by-op mode should not be used. PiperOrigin-RevId: 329816288 Change-Id: I80f8d94d23db81ae004b31e73e6f94b8cbc096f8
This commit is contained in:
parent
2cdb2b4d76
commit
eba3f769ec
@ -113,14 +113,6 @@ Status XlaCompileOnDemandOp::Compile(
|
|||||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||||
const Tensor& device_tensor = ctx->input(i);
|
const Tensor& device_tensor = ctx->input(i);
|
||||||
|
|
||||||
if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
|
|
||||||
if (xla_tensor->has_host_tensor()) {
|
|
||||||
if (absl::c_binary_search(constant_input_indices, i)) {
|
|
||||||
constant_arguments[i] = xla_tensor->host_tensor();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!constant_arguments.count(i)) {
|
if (!constant_arguments.count(i)) {
|
||||||
if (absl::c_binary_search(constant_input_indices, i)) {
|
if (absl::c_binary_search(constant_input_indices, i)) {
|
||||||
if (ctx->input_memory_type(i) != HOST_MEMORY) {
|
if (ctx->input_memory_type(i) != HOST_MEMORY) {
|
||||||
|
@ -352,9 +352,6 @@ static Status SetOutputForConstant(
|
|||||||
ctx->set_output(output_num, const_tensor);
|
ctx->set_output(output_num, const_tensor);
|
||||||
output_tensor = ctx->mutable_output(output_num);
|
output_tensor = ctx->mutable_output(output_num);
|
||||||
}
|
}
|
||||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
|
||||||
xla_tensor->set_host_tensor(const_tensor);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,18 +71,6 @@ class XlaTensor {
|
|||||||
shaped_buffer_ = std::move(shaped_buffer);
|
shaped_buffer_ = std::move(shaped_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some tensors on the device may have known values on the host. We use these
|
|
||||||
// in on-demand mode to avoid re-copying values from the device if we know the
|
|
||||||
// host value already.
|
|
||||||
|
|
||||||
// Return true if this XlaTensor contains a host tensor.
|
|
||||||
bool has_host_tensor() const { return host_tensor_.has_value(); }
|
|
||||||
// Return the contained host tensor.
|
|
||||||
// REQUIRES: has_host_tensor()
|
|
||||||
const Tensor& host_tensor() const { return *host_tensor_; }
|
|
||||||
// Sets the contained host tensor.
|
|
||||||
void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
|
|
||||||
|
|
||||||
// Adds synchronization events to 'stream' that wait for this tensor to be
|
// Adds synchronization events to 'stream' that wait for this tensor to be
|
||||||
// defined on 'stream'. Does nothing if the tensor is already defined on that
|
// defined on 'stream'. Does nothing if the tensor is already defined on that
|
||||||
// stream.
|
// stream.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user