/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_

#include <atomic>

#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
namespace tensorflow {
// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require the arguments to follow
// the "constants, then regular args, then resources" ordering.
// It takes vectors of constant and resource arguments explicitly.
// It does not have a corresponding OpDef because it is never present
// in the GraphDef.
// Currently, it is used by the eager runtime: FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
//
// `has_ref_vars`: whether the input computation can have reference variables.
// TODO(cheshire): instead derive this information from the input graph.
class XlaLocalLaunchBase : public OpKernel {
 public:
  XlaLocalLaunchBase(OpKernelConstruction* ctx,
                     const std::vector<int>& constants,
                     const std::vector<int>& resources,
                     const NameAttrList& function, bool has_ref_vars);
  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
  ~XlaLocalLaunchBase() override = default;

  void Compute(OpKernelContext* ctx) override;

 protected:
  // Indexes of compile-time constant inputs.
  const std::vector<int> constants_;
  // Indexes of resource inputs.
  const std::vector<int> resources_;

  const NameAttrList function_;
  const XlaPlatformInfo platform_info_;

  bool has_ref_vars_;
};
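
// Example (a sketch, not part of the production code): the explicit index
// vectors tell the kernel which inputs are compile-time constants and which
// are resources. `ctx` and `fn` below are hypothetical placeholders for
// values normally supplied by the framework.
//
//   std::vector<int> constants = {0};  // input 0 is a compile-time constant
//   std::vector<int> resources = {2};  // input 2 is a resource variable
//   XlaLocalLaunchBase kernel(ctx, constants, resources, fn,
//                             /*has_ref_vars=*/true);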

// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// Once all inputs are present, and their shapes are known, the op can
// use an 'XlaCompilationCache' to compile and execute code which is specific
// to the shapes of the input Tensors.
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
class XlaLocalLaunchOp : public XlaLocalLaunchBase {
 public:
  explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
  ~XlaLocalLaunchOp() override;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
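
// A rough pseudocode sketch of the compile-then-run flow described above; the
// real logic lives in the corresponding .cc file and additionally handles
// caching, argument snapshots, and error fallback:
//
//   // On a cache miss this ends up calling xla::LocalClient::Compile().
//   executable = LookupOrCompile(compilation_cache, function, input_shapes);
//   // Inputs and outputs stay in device memory.
//   outputs = executable->Run(device_memory_args, run_options);
//
// `LookupOrCompile` is a hypothetical name used here for illustration only.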

// XlaCompileOp compiles the function attached to it into an XLA executable
// and produces a key that an XlaRunOp can later use to execute the compiled
// result.
class XlaCompileOp : public OpKernel {
 public:
  explicit XlaCompileOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  // Indexes of compile-time constant inputs.
  const std::vector<int> constants_;
  // Indexes of resource inputs.
  const std::vector<int> resources_;

  const NameAttrList function_;
  XlaPlatformInfo platform_info_;
  const bool must_compile_;

  // Whether the graph has TF reference variables.
  const bool has_ref_vars_;

  // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
  // error when compiling the cluster this _XlaCompile is supposed to compile.
  // If `cannot_compile_cluster_` is true then we avoid compiling this cluster
  // on any future calls to _XlaCompile.
  bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) =
      false;
  mutex cannot_compile_cluster_mu_;
};
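
// The `cannot_compile_cluster_` flag above is read and written under its
// mutex; a sketch of the pattern (the authoritative logic is in the .cc
// file):
//
//   {
//     mutex_lock guard(cannot_compile_cluster_mu_);
//     if (cannot_compile_cluster_) {
//       // Skip XLA entirely and let the cluster run on the TF executor.
//     }
//   }
//   ...
//   if (compilation_status.code() == error::UNIMPLEMENTED) {
//     mutex_lock guard(cannot_compile_cluster_mu_);
//     cannot_compile_cluster_ = true;  // never retry compiling this cluster
//   }
//
// `compilation_status` here is a hypothetical Status value for illustration.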

// XlaRunOp executes an executable previously compiled by XlaCompileOp,
// identified by the key it receives as an input.
class XlaRunOp : public OpKernel {
 public:
  explicit XlaRunOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  const XlaPlatformInfo platform_info_;
};

// XlaMergeOp forwards whichever of its two inputs is live, merging the XLA
// path (the output of an XlaRunOp) with the fallback TensorFlow path.
class XlaMergeOp : public OpKernel {
 public:
  explicit XlaMergeOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;
};
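
// Taken together, _XlaCompile, _XlaRun, and _XlaMerge implement lazy
// compilation for a clustered subgraph. A sketch of the rewritten graph (the
// authoritative wiring is produced by the build_xla_ops pass):
//
//   key, ok = _XlaCompile(args)       // XlaCompileOp
//   xla_out = _XlaRun(args, key)      // XlaRunOp, taken when `ok`
//   tf_out  = FunctionCall(args)      // fallback TF execution
//   out     = _XlaMerge(xla_out, tf_out)
//
// Names like `ok` and `FunctionCall` are illustrative, not actual op names.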

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_