[TF2XLA] Explicitly copy must-be-constants to host in get_compiler_ir
PiperOrigin-RevId: 334878548 Change-Id: Icfb2a4806d35c4974887aa19fc55b382a8a7ff5b
This commit is contained in:
parent
55cadd046b
commit
8991d372bf
@ -435,6 +435,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||||
|
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
|
|||||||
|
|
||||||
xla::StatusOr<std::string> GetCompilerIr(
|
xla::StatusOr<std::string> GetCompilerIr(
|
||||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||||
absl::string_view func_name, Device* dev,
|
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||||
absl::Span<const Tensor* const> inputs) {
|
absl::Span<const TensorHandle* const> inputs_handles) {
|
||||||
NameAttrList function;
|
NameAttrList function;
|
||||||
function.set_name(std::string{func_name});
|
function.set_name(std::string{func_name});
|
||||||
|
|
||||||
@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
|
|||||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||||
|
|
||||||
|
std::deque<Tensor> inputs_storage;
|
||||||
|
std::vector<const Tensor*> inputs;
|
||||||
|
inputs.reserve(inputs_handles.size());
|
||||||
|
for (int i = 0; i < inputs_handles.size(); i++) {
|
||||||
|
const TensorHandle* th = inputs_handles[i];
|
||||||
|
const Tensor* t;
|
||||||
|
// Handle owns the tensor.
|
||||||
|
TF_RETURN_IF_ERROR(th->Tensor(&t));
|
||||||
|
if (absl::c_binary_search(constant_arg_indices, i)) {
|
||||||
|
// Need to make sure it's on the host.
|
||||||
|
inputs_storage.emplace_back(t->dtype(), t->shape());
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
|
||||||
|
inputs.push_back(&inputs_storage.back());
|
||||||
|
} else {
|
||||||
|
inputs.push_back(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<VariableInfo> variable_infos;
|
std::vector<VariableInfo> variable_infos;
|
||||||
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
||||||
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
||||||
|
@ -24,6 +24,8 @@ namespace tensorflow {
|
|||||||
class ProcessFunctionLibraryRuntime;
|
class ProcessFunctionLibraryRuntime;
|
||||||
class Device;
|
class Device;
|
||||||
class Tensor;
|
class Tensor;
|
||||||
|
class TensorHandle;
|
||||||
|
class EagerContext;
|
||||||
|
|
||||||
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
||||||
|
|
||||||
@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
|||||||
// `runtime` on a device `dev` with given `inputs`.
|
// `runtime` on a device `dev` with given `inputs`.
|
||||||
xla::StatusOr<std::string> GetCompilerIr(
|
xla::StatusOr<std::string> GetCompilerIr(
|
||||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||||
absl::string_view func_name, Device* dev,
|
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||||
absl::Span<const Tensor* const> inputs);
|
absl::Span<const TensorHandle* const> inputs);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -643,6 +643,19 @@ class DefFunctionTest(xla_test.XLATestCase):
|
|||||||
self.assertIn('tuple',
|
self.assertIn('tuple',
|
||||||
f.experimental_get_compiler_ir(l)())
|
f.experimental_get_compiler_ir(l)())
|
||||||
|
|
||||||
|
def testConstantOnWrongDevice(self):
|
||||||
|
with ops.device('device:{}:0'.format(self.device)):
|
||||||
|
|
||||||
|
s = math_ops.cast(random_ops.random_normal([2]), dtypes.int32)
|
||||||
|
l = random_ops.random_normal([s[0] * s[1]])
|
||||||
|
|
||||||
|
@def_function.function(experimental_compile=True)
|
||||||
|
def f(l):
|
||||||
|
return array_ops.reshape(l, s)
|
||||||
|
|
||||||
|
self.assertIn('tuple',
|
||||||
|
f.experimental_get_compiler_ir(l)())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
@ -314,19 +314,10 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
|
|||||||
|
|
||||||
TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
|
TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
|
||||||
|
|
||||||
std::vector<const Tensor*> input_tensors;
|
std::vector<const TensorHandle*> input_handles;
|
||||||
for (TFE_TensorHandle* tensor_handle : handles) {
|
for (TFE_TensorHandle* tensor_handle : handles) {
|
||||||
AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
|
AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
|
||||||
TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle);
|
input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
|
||||||
|
|
||||||
const Tensor* t;
|
|
||||||
Status st = th->Tensor(&t);
|
|
||||||
if (!st.ok()) {
|
|
||||||
ThrowValueError(
|
|
||||||
absl::StrFormat("Could not resolve tensor: '%s'", st.error_message())
|
|
||||||
.c_str());
|
|
||||||
}
|
|
||||||
input_tensors.push_back(t);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceNameUtils::ParsedName input_device_name;
|
DeviceNameUtils::ParsedName input_device_name;
|
||||||
@ -347,7 +338,7 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
|
|||||||
|
|
||||||
xla::StatusOr<std::string> hlo_text =
|
xla::StatusOr<std::string> hlo_text =
|
||||||
GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
|
GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
|
||||||
*selected_device, input_tensors);
|
*selected_device, context, input_handles);
|
||||||
|
|
||||||
if (!hlo_text.ok()) {
|
if (!hlo_text.ok()) {
|
||||||
ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
|
ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user