[TF2XLA] Remove XlaTensor::set_host_tensor. It creates unnecessary complication
in the tf2xla bridge. If caching is truly needed, it can be maintained in the side datastructure. Extra copying should not justify complexity of the implementation: if extra copies are a concern, an op-by-op mode should not be used. PiperOrigin-RevId: 329816288 Change-Id: I80f8d94d23db81ae004b31e73e6f94b8cbc096f8
This commit is contained in:
parent
2cdb2b4d76
commit
eba3f769ec
tensorflow/compiler/jit
@ -113,14 +113,6 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
|
||||
const Tensor& device_tensor = ctx->input(i);
|
||||
|
||||
if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
|
||||
if (xla_tensor->has_host_tensor()) {
|
||||
if (absl::c_binary_search(constant_input_indices, i)) {
|
||||
constant_arguments[i] = xla_tensor->host_tensor();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!constant_arguments.count(i)) {
|
||||
if (absl::c_binary_search(constant_input_indices, i)) {
|
||||
if (ctx->input_memory_type(i) != HOST_MEMORY) {
|
||||
|
@ -352,9 +352,6 @@ static Status SetOutputForConstant(
|
||||
ctx->set_output(output_num, const_tensor);
|
||||
output_tensor = ctx->mutable_output(output_num);
|
||||
}
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
xla_tensor->set_host_tensor(const_tensor);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -71,18 +71,6 @@ class XlaTensor {
|
||||
shaped_buffer_ = std::move(shaped_buffer);
|
||||
}
|
||||
|
||||
// Some tensors on the device may have known values on the host. We use these
|
||||
// in on-demand mode to avoid re-copying values from the device if we know the
|
||||
// host value already.
|
||||
|
||||
// Return true if this XlaTensor contains a host tensor.
|
||||
bool has_host_tensor() const { return host_tensor_.has_value(); }
|
||||
// Return the contained host tensor.
|
||||
// REQUIRES: has_host_tensor()
|
||||
const Tensor& host_tensor() const { return *host_tensor_; }
|
||||
// Sets the contained host tensor.
|
||||
void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
|
||||
|
||||
// Adds synchronization events to 'stream' that wait for this tensor to be
|
||||
// defined on 'stream'. Does nothing if the tensor is already defined on that
|
||||
// stream.
|
||||
|
Loading…
Reference in New Issue
Block a user