Merge branch 'master' into yunfeimao/matmul_tanh_fusion

Mao Yunfei 2020-08-31 10:44:54 +08:00 committed by GitHub
commit 3f9d668ed0
512 changed files with 9838 additions and 4872 deletions

View File

@ -40,6 +40,22 @@ segfault_memory:
# assignees
filesystem_security_assignee:
- mihaimaruseac
tflite_micro_path:
- tensorflow/lite/micro
tflite_micro_comment: >
Thanks for contributing to TensorFlow Lite Micro.
To keep this process moving along, we'd like to make sure that you have completed the items on this list:
* Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md)
* Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
* Linked to the issue from the PR description
We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review.
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:

View File

@ -37,6 +37,9 @@
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
removed).
* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type
`tf.complex64` or `tf.complex128`, because the behavior of these ops is not
well defined for complex types.
## Known Caveats
@ -120,6 +123,13 @@
customization of how gradients are aggregated across devices, as well as
`gradients_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* The `steps_per_execution` argument in `compile()` is no longer
experimental; if you were passing `experimental_steps_per_execution`,
rename it to `steps_per_execution` in your code. This argument controls
the number of batches to run during each `tf.function` call when calling
`fit()`. Running multiple batches inside a single `tf.function` call can
greatly improve performance on TPUs or small models with a large Python
overhead.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
@ -147,6 +157,8 @@
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
* TFLite Profiler for Android is available. See the detailed
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>

View File

@ -387,6 +387,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"//tensorflow/core/platform:blocking_counter",
"@com_google_absl//absl/strings",
],
alwayslink = 1,

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/net.h"
@ -560,6 +561,21 @@ TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
collective_executor_handle->get()->StartAbort(status->status);
}
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
const char* task,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
tensorflow::Notification done;
collective_executor_handle->get()->remote_access()->CheckPeerHealth(
task, [&done, status](const Status& s) {
status->status = s;
done.Notify();
});
done.WaitForNotification();
}
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
result->num_items = num_items;

View File

@ -238,6 +238,13 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status);
// Checks the health of collective ops peers. Explicit health check is needed in
// multi worker collective ops to detect failures in the cluster. If a peer is
// down, collective ops may hang.
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(TFE_Context* ctx,
const char* task,
TF_Status* status);
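A minimal usage sketch based only on the declaration above: the helper name and the task string are placeholders, and the error handling mirrors the CHECK_EQ pattern used by the tests elsewhere in this diff. The implementation hunk earlier waits on a Notification, so the call blocks until the peer answers or the check fails.

// Sketch only; "/job:worker/replica:0/task:1" is a placeholder task name.
void CheckPeerOrDie(TFE_Context* ctx) {
  TF_Status* status = TF_NewStatus();
  // Blocks until the peer responds or the health check reports an error.
  TFE_CollectiveOpsCheckPeerHealth(ctx, "/job:worker/replica:0/task:1",
                                   status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}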
// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {
// Number of dimensions. -1 indicates unknown rank.

View File

@ -1704,66 +1704,5 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
TF_DeleteFunction(func1);
}
// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
TF_Function* func;
const std::string funcName = "BranchFunc";
DefineFunction(funcName.c_str(), &func);
TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* feed = Placeholder(host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_OperationDescription* desc =
TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
TF_AddInput(desc, {true_cond, 0});
TF_Output inputs[] = {{feed, 0}};
TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_SetAttrType(desc, "Tcond", TF_BOOL);
TF_DataType inputType = TF_INT32;
TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
TF_SetDevice(desc, "/device:XLA_CPU:0");
auto op = TF_FinishOperation(desc, s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
ASSERT_NE(op, nullptr);
// Create a session for this graph.
CSession csession(host_graph_, s_, /*use_XLA*/ true);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Run the graph.
csession.SetInputs({{feed, Int32Tensor(17)}});
csession.SetOutputs({op});
csession.Run(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
EXPECT_EQ(TF_INT32, TF_TensorType(out));
EXPECT_EQ(0, TF_NumDims(out)); // scalar
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
EXPECT_EQ(-17, *output_contents);
// Clean up
csession.CloseAndDelete(s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_DeleteFunction(func);
}
#endif // TENSORFLOW_EAGER_USE_XLA
} // namespace
} // namespace tensorflow

View File

@ -6,7 +6,6 @@ load(
"tf_copts",
"tf_cuda_cc_test",
"tf_cuda_library",
"tfe_xla_copts",
)
load(
"//tensorflow/core/platform:build_config.bzl",
@ -31,7 +30,7 @@ tf_cuda_library(
"c_api_unified_experimental.h",
],
hdrs = ["c_api.h"],
copts = tf_copts() + tfe_xla_copts(),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
@ -72,13 +71,6 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
],
}) + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/jit:xla_device",
],
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation",
@ -228,7 +220,6 @@ tf_cuda_cc_test(
"gradients_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
@ -278,6 +269,7 @@ cc_library(
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:span",
],
@ -290,12 +282,9 @@ tf_cuda_cc_test(
"mnist_gradients_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"nomac",
"notap", # TODO(b/166150182): Enable
"no_oss", # TODO(b/166150182): Enable
],
deps = [
":abstract_tensor_handle",
@ -553,7 +542,6 @@ tf_cuda_cc_test(
"c_api_debug_test.cc",
"c_api_test.cc",
],
extra_copts = tfe_xla_copts(),
tags = [
"noguitar", # TODO(b/155445984): flaky
#"guitar",
@ -608,7 +596,6 @@ tf_cuda_cc_test(
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
@ -641,7 +628,6 @@ tf_cuda_cc_test(
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
@ -660,7 +646,6 @@ tf_cuda_cc_test(
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
@ -694,7 +679,6 @@ tf_cuda_cc_test(
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
@ -729,7 +713,7 @@ tf_cuda_library(
"c_api_experimental.h",
"c_api_unified_experimental.h",
],
copts = tf_copts() + tfe_xla_copts(),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
@ -801,7 +785,6 @@ tf_cuda_cc_test(
"c_api_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
@ -825,7 +808,6 @@ tf_cuda_cc_test(
"c_api_unified_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [

View File

@ -51,9 +51,6 @@ limitations under the License.
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -1148,26 +1145,23 @@ void TFE_DeleteOp(TFE_Op* op) {
tensorflow::unwrap(op)->Release();
}
const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
return tensorflow::unwrap(op)->Name().c_str();
}
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
return tensorflow::wrap(
&(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
}
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
return tensorflow::unwrap(op)->DeviceName().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Status s = tensorflow::unwrap(op)->SetUseXla(enable);
if (!s.ok()) {
LOG(ERROR) << "Could not enable XLA compilation for op: " << s;
}
#else
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
#endif // TENSORFLOW_EAGER_USE_XLA
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}
@ -1180,6 +1174,15 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
static_cast<size_t>(num_inputs)});
}
extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
return tensorflow::unwrap(op)->GetInputs().size();
}
extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
TF_Status* status) {
return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret = TF_ATTR_INT;
@ -1485,7 +1488,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) {
tensorflow::unwrap(ctx)->EndStep();
}
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
return tensorflow::wrap(
&OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}
@ -1611,19 +1614,12 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
return status.status;
}
tensorflow::Status Execute(tensorflow::EagerOperation* op,
tensorflow::Status Execute(const tensorflow::EagerOperation* op,
tensorflow::TensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> inputs;
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
@ -1633,10 +1629,6 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_DeleteTensorHandle(outputs[i]);
}
}
for (auto inp : inputs) {
TFE_DeleteTensorHandle(inp);
}
return status.status;
}

View File

@ -248,22 +248,22 @@ typedef struct TFE_Op TFE_Op;
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
const char* op_or_function_name,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
// Returns the op or function name `op` will execute.
//
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
TF_Status* status);
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op,
TF_Status* status);
// When 'enable' is set to 1, and if TensorFlow library is built with XLA
// support, a subsequent TFE_Execute() call on `op` will run the op via XLA.
//
// If the library is not built with XLA support, this call would be a no-op.
TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op,
unsigned char enable);
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input,
TF_Status* status);
@ -272,6 +272,23 @@ TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op,
int num_inputs,
TF_Status* status);
// Fetches the current number of inputs attached to `op`.
//
// Does not use the operation's definition to determine how many inputs should
// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an
// already-finalized operation.
//
// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat
// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a
// particular named input list, which may only be part of the op's inputs).
TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op,
TF_Status* status);
// Returns a borrowed reference to one of `op`'s inputs. Use
// `TFE_TensorHandleCopySharingTensor` to make a new reference.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op,
int index,
TF_Status* status);
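A short sketch of how the two calls compose; the function name is illustrative, and the CloneOp helper in the test hunk further down follows the same pattern. The handles are borrowed, so nothing is deleted here.

void LogFlatInputDevices(const TFE_Op* op, TF_Status* status) {
  int num_inputs = TFE_OpGetFlatInputCount(op, status);
  if (TF_GetCode(status) != TF_OK) return;
  for (int i = 0; i < num_inputs; ++i) {
    // Borrowed reference; use TFE_TensorHandleCopySharingTensor if the
    // handle needs to outlive `op`.
    TFE_TensorHandle* input = TFE_OpGetFlatInput(op, i, status);
    if (TF_GetCode(status) != TF_OK) return;
    const char* device = TFE_TensorHandleDeviceName(input, status);
    if (TF_GetCode(status) != TF_OK) return;
    LOG(INFO) << "input " << i << " is on " << device;
  }
}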
TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op,
const char* attr_name,
unsigned char* is_list,

View File

@ -22,9 +22,6 @@ limitations under the License.
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/status.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/jit/xla_device.h"
#endif // TENSORFLOW_EAGER_USE_XLA
using tensorflow::string;
@ -64,87 +61,6 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
auto* device = absl::get<tensorflow::Device*>(handle->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
if (xla_device != nullptr) {
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
status->status = shape_fn(*tensor, &padded_shape);
if (!status->status.ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
std::vector<tensorflow::int64> shape_to_log =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
// Ignore the status here as we are simply logging.
status->status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
<< padded_shape.DebugString();
}
}
if (padded_shape.IsTuple()) {
if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) {
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
}
// shape0 is not a const& because we will assign it to padded_shape below.
// It is illegal to assign a part of a message to itself.
xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0);
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
status->status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
}
// Since the only case we handle here are two equal subshapes, we
// simply return one of them. The caller will interpret it as this
// shape directly storing the 64bit types. This approximation is good
// enough for this API's debugging use case.
padded_shape = shape0;
}
int rank = padded_shape.dimensions_size();
std::vector<tensorflow::int64> dev_dims;
dev_dims.reserve(rank);
if (rank == 1) {
// Rank 1 tensors might not have padded_shape.layout.minor_to_major set,
dev_dims.push_back(padded_shape.dimensions(0));
} else {
for (int i = rank - 1; i >= 0; --i) {
tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i);
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
status->status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
std::vector<tensorflow::int64> dev_dims =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {

View File

@ -414,7 +414,7 @@ typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op);
TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
@ -435,7 +435,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
size_t proto_len,
TF_Status* status);
#define TFE_CUSTOM_DEVICE_VERSION 2
// TODO(b/166642410): It would be nice, for custom devices and for other users,
// to have a non-string representation of devices (TF_Device) extracted from
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
#define TFE_CUSTOM_DEVICE_VERSION 3
// Struct to be filled in
typedef struct TFE_CustomDevice {
@ -454,9 +458,16 @@ typedef struct TFE_CustomDevice {
void* device_info);
// Method to execute an operation.
void (*execute)(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
//
// Arguments provide enough information to reconstruct the original `TFE_Op`,
// or construct a transformed version, by inspecting the passed `op`.
//
// TFE_OpGetDevice(op) records the original placement of the operation. It may
// be an empty string if no device was explicitly requested, but will
// otherwise be the name of this custom device. Ops are placed onto a custom
// device if any of their inputs are on that custom device, but custom devices
// are free to set a bad status in order to require explicit placement.
void (*execute)(const TFE_Op* op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.
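To make the version-3 contract concrete, here is a skeleton execute callback written only against the accessors added in this change; the name is hypothetical, and the logging-device and parallel-device hunks later in this diff show complete implementations of the same pattern.

void ExampleDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  // Everything that previously arrived as separate arguments is now
  // recovered from the op itself.
  TFE_Context* context = TFE_OpGetContext(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  const char* operation_name = TFE_OpGetName(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
  int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
  if (TF_GetCode(s) != TF_OK) return;
  // A real device would now rebuild the op with TFE_NewOp(context,
  // operation_name, s), copy `attributes` with TFE_OpAddAttrs, add the flat
  // inputs, place the op on an underlying device, and call TFE_Execute to
  // fill `outputs` and `num_outputs`.
  (void)context;
  (void)operation_name;
  (void)attributes;
  (void)num_inputs;  // Unused in this stub.
}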

View File

@ -316,86 +316,6 @@ TEST(CAPI, Function_ident_CPU) {
TF_DeleteStatus(status);
}
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST(CAPI, Function_ident_XLA_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
TF_OperationDescription* arg_descr =
TF_NewOperation(function_graph, "Placeholder", "arg");
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
TF_Status* status = TF_NewStatus();
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_OperationDescription* id_descr =
TF_NewOperation(function_graph, "Identity", "id");
TF_SetAttrType(id_descr, "T", TF_INT32);
TF_AddInput(id_descr, {arg, 0});
TF_Operation* id = TF_FinishOperation(id_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_Output input{arg, 0};
TF_Output output{id, 0};
TF_Function* fn =
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
&output, nullptr, nullptr, "test", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteGraph(function_graph);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextAddFunction(ctx, fn, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);
for (bool async : {false, true, false}) {
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
TFE_Executor* executor = TFE_NewExecutor(async);
TFE_ContextSetExecutorForThread(ctx, executor);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK);
TF_Tensor* t =
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteTensor(t);
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_OpAddInput(op, h, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
// Now run it via XLA.
TFE_OpSetXLACompilation(op, true);
std::vector<TFE_TensorHandle*> result;
result.push_back(nullptr);
int num_retvals = 1;
TFE_Execute(op, result.data(), &num_retvals, status);
TFE_DeleteOp(op);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
ASSERT_EQ(num_retvals, 1);
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
TFE_ContextSetExecutorForThread(ctx, old_executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_DeleteExecutor(old_executor);
TFE_DeleteTensorHandle(h);
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
#endif // TENSORFLOW_EAGER_USE_XLA
void Executor_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -876,89 +876,6 @@ TEST(CAPI, Execute_Min_CPU) {
TF_DeleteStatus(status);
}
#ifdef TENSORFLOW_EAGER_USE_XLA
void Execute_MatMul_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_OpSetXLACompilation(matmul, true);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
// Running a primitive TF operator via XLA is not yet supported.
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); }
TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); }
void Execute_Min_XLA_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
TFE_Op* minOp = MinOp(ctx, input, axis);
TFE_OpSetXLACompilation(minOp, true);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(minOp, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(minOp);
TFE_DeleteTensorHandle(input);
TFE_DeleteTensorHandle(axis);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float output[2] = {0};
EXPECT_EQ(sizeof(output), TF_TensorByteSize(t));
memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(1, output[0]);
EXPECT_EQ(3, output[1]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); }
TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); }
#endif // TENSORFLOW_EAGER_USE_XLA
void ExecuteWithTracing(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -1620,4 +1537,91 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
TFE_DeleteContext(ctx);
}
// Needs to work with a const TFE_Op since custom devices should not modify the
// op they are called with.
TFE_Op* CloneOp(const TFE_Op* other) {
TF_Status* status = TF_NewStatus();
TFE_Context* context = TFE_OpGetContext(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char* op_name = TFE_OpGetName(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* ret = TFE_NewOp(context, op_name, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char* device = TFE_OpGetDevice(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetDevice(ret, device, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddAttrs(ret, TFE_OpGetAttrs(other));
int num_inputs = TFE_OpGetFlatInputCount(other, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
for (int input_index = 0; input_index < num_inputs; ++input_index) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(other, input_index, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(ret, input, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TF_DeleteStatus(status);
return ret;
}
TEST(CAPI, TestTFE_OpRecreation) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
// Clone an op with attributes and a device set.
TFE_Op* original_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(original_var_op, "dtype", TF_INT64);
TFE_OpSetAttrShape(original_var_op, "shape", {}, 0, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("", std::string(TFE_OpGetDevice(original_var_op, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetDevice(original_var_op,
"/job:localhost/replica:0/task:0/device:CPU:0", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* cloned = CloneOp(original_var_op);
EXPECT_EQ("/job:localhost/replica:0/task:0/device:CPU:0",
std::string(TFE_OpGetDevice(cloned, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
EXPECT_EQ("VarHandleOp", std::string(TFE_OpGetName(cloned, status)));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_TensorHandle* ret;
TFE_Execute(cloned, &ret, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(ret);
// Clone an op with inputs and no device set.
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
TFE_Op* original_identity = TFE_NewOp(ctx, "IdentityN", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* inputs[] = {input1, input2};
TFE_OpAddInputList(original_identity, inputs, 2, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* cloned_identity = CloneOp(original_identity);
EXPECT_EQ("", std::string(TFE_OpGetDevice(cloned_identity, status)));
TFE_TensorHandle* identity_ret[] = {nullptr, nullptr};
num_retvals = 2;
TFE_Execute(cloned_identity, identity_ret, &num_retvals, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(input1);
TFE_DeleteTensorHandle(input2);
TFE_DeleteTensorHandle(identity_ret[0]);
TFE_DeleteTensorHandle(identity_ret[1]);
TFE_DeleteOp(cloned_identity);
TFE_DeleteOp(original_identity);
TFE_DeleteOp(original_var_op);
TFE_DeleteOp(cloned);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
} // namespace

View File

@ -36,7 +36,8 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
ASSERT_FALSE(arrived);
@ -73,7 +74,8 @@ TEST(CUSTOM_DEVICE, ResetOperation) {
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom_device_name,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@ -103,7 +105,8 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
@ -187,7 +190,8 @@ TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
@ -264,10 +268,12 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom0,
/*strict_scope_placement=*/false, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
RegisterLoggingDevice(context.get(), custom1,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
@ -314,14 +320,34 @@ TEST(CUSTOM_DEVICE, InputBasedPlacement) {
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
// Custom device: mix of custom/physical fails.
// Custom device: mix of custom/physical places the op on the custom device.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
EXPECT_TRUE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Explicit placement still forces the op onto the requested device
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
// Custom devices can refuse to do type-based dispatch (as hcustom1 is
// configured to do)
matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
@ -334,21 +360,24 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0", &arrived, &executed,
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(context.get(),
"/job:localhost/replica:0/task:0/device:CPU:0",
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(
context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
/*strict_scope_placement=*/true, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}

View File

@ -33,6 +33,9 @@ struct LoggingDevice {
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
// If true, only explicit op placements are accepted. If false, uses
// type-based dispatch.
bool strict_scope_placement;
};
struct LoggedTensor {
@ -84,18 +87,35 @@ TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
return nullptr;
}
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
if (dev->strict_scope_placement && *requested_placement == '\0') {
TF_SetStatus(s, TF_INTERNAL,
"Ops must be placed on the device explicitly, or their inputs "
"first copied to other devices.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
if (TF_GetCode(s) != TF_OK) return;
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
@ -131,8 +151,8 @@ void DeleteLoggingDevice(void* device_info) {
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status) {
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
@ -143,6 +163,7 @@ void RegisterLoggingDevice(TFE_Context* context, const char* name,
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
device->strict_scope_placement = strict_scope_placement;
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
@ -168,5 +189,6 @@ void AllocateLoggingDevice(const char* name, bool* arrived_flag,
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
logging_device->strict_scope_placement = true;
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag, bool* executed_flag,
TF_Status* status);
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);

View File

@ -242,6 +242,7 @@ namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {
forward_op_->op_name = op;
forward_op_->attrs.Reset(op);
return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
@ -418,6 +419,11 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
}
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
// attributes. Number type attrs and DataType attrs work fine without this.
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_->attrs.BuildNodeDef();
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx));

View File

@ -507,6 +507,57 @@ TEST_P(CppGradients, TestIdentityNGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr t;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
t.reset(x_raw);
}
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx.get();
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
/*raw_device_name=*/nullptr, &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
if (isa<TracingOperation>(check_numerics_op.get())) {
s = dyn_cast<TracingOperation>(check_numerics_op.get())
->SetOpName("check_numerics");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
s = AddInput(check_numerics_op.get(), t.get(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string message = "This is the way!";
s = SetAttrString(check_numerics_op.get(), "message", message.data(),
message.length(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
int num_retvals = 1;
std::vector<AbstractTensorHandle*> outputs(1);
GradientRegistry registry;
std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
&num_retvals, &forward_op, tape.get(), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string read_message;
s = forward_op.attrs.Get("message", &read_message);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_EQ(read_message, message);
}
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE

View File

@ -47,9 +47,6 @@ class ImmediateExecutionOperation : public AbstractOperation {
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
// Set stack trace to be used for potential async error reporting.
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;

View File

@ -765,13 +765,13 @@ TEST_P(CppGradients, TestMNIST_Training) {
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif

View File

@ -31,11 +31,15 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
using std::vector;
using tracing::TracingOperation;
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
using std::vector;
using tensorflow::tracing::TracingOperation;
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
@ -272,6 +276,7 @@ Status MNISTForwardModel(AbstractContext* ctx,
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
@ -592,3 +597,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
TFE_DeleteContextOptions(opts);
return Status::OK();
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -27,13 +27,13 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;
#include "tensorflow/core/platform/status.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
@ -144,3 +144,7 @@ Status RunModel(Model model, AbstractContext* ctx,
const GradientRegistry& registry);
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -255,28 +255,44 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
TFE_TensorHandle** inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int* num_outputs,
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, status);
if (*requested_placement == '\0') {
TF_SetStatus(
status, TF_INTERNAL,
"Ops must be placed on the parallel device explicitly, or their inputs "
"first un-packed. Got an un-placed op with an input placed on the "
"parallel device.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
if (TF_GetCode(status) != TF_OK) return;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
if (TF_GetCode(status) != TF_OK) return;
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
TFE_TensorHandleDevicePointer(inputs[i], status)));
TFE_TensorHandleDevicePointer(input, status)));
if (TF_GetCode(status) != TF_OK) return;
} else {
typed_inputs.emplace_back(inputs[i]);
typed_inputs.emplace_back(input);
}
}

View File

@ -29,6 +29,7 @@ cc_library(
":gcs_helper",
":ram_file_block_cache",
"//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
@ -134,6 +135,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
}
// `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here.
TF_SetStatus(status, TF_OK, "");
TF_VLog(1, "Successful read of %s @ %u of size: %u", path.c_str(), offset,
read);
stream.read(buffer, read);
read = stream.gcount();
if (read < buffer_size) {
@ -146,6 +149,8 @@ static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
path, " @ ", offset)
.c_str());
}
TF_VLog(2, "Successful integrity check for: %s @ %u", path.c_str(),
offset);
}
}
return read;
@ -284,6 +289,8 @@ static void SyncImpl(const std::string& bucket, const std::string& object,
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
TF_VLog(3, "AppendObject: gs://%s/%s to gs://%s/%s", bucket.c_str(),
temporary_object.c_str(), bucket.c_str(), object.c_str());
const std::vector<gcs::ComposeSourceObject> source_objects = {
{object, {}, {}}, {temporary_object, {}, {}}};
metadata = gcs_client->ComposeObject(bucket, source_objects, object);
@ -321,6 +328,8 @@ void Append(const TF_WritableFile* file, const char* buffer, size_t n,
"The internal temporary file is not writable.");
return;
}
TF_VLog(3, "Append: gs://%s/%s size %u", gcs_file->bucket.c_str(),
gcs_file->object.c_str(), n);
gcs_file->sync_need = true;
gcs_file->outfile.write(buffer, n);
if (!gcs_file->outfile)
@ -346,6 +355,8 @@ int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
void Flush(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (gcs_file->sync_need) {
TF_VLog(3, "Flush started: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (!gcs_file->outfile) {
TF_SetStatus(status, TF_INTERNAL,
"Could not append to the internal temporary file.");
@ -353,6 +364,8 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
}
SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
&gcs_file->outfile, gcs_file->gcs_client, status);
TF_VLog(3, "Flush finished: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (TF_GetCode(status) != TF_OK) return;
gcs_file->sync_need = false;
} else {
@ -361,11 +374,16 @@ void Flush(const TF_WritableFile* file, TF_Status* status) {
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
TF_VLog(3, "Sync: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
Flush(file, status);
}
void Close(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
TF_VLog(3, "Close: gs://%s/%s", gcs_file->bucket.c_str(),
gcs_file->object.c_str());
if (gcs_file->sync_need) {
Flush(file, status);
}
@ -428,6 +446,8 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
max_staleness = value;
}
TF_VLog(1, "GCS cache max size = %u ; block size = %u ; max staleness = %u",
max_bytes, block_size, max_staleness);
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
@ -511,6 +531,10 @@ static void UncachedStatForObject(const std::string& bucket,
stat->base.mtime_nsec =
metadata->time_storage_class_updated().time_since_epoch().count();
stat->base.is_directory = object.back() == '/';
TF_VLog(1,
"Stat of: gs://%s/%s -- length: %u generation: %u; mtime_nsec: %u;",
bucket.c_str(), object.c_str(), stat->base.length,
stat->generation_number, stat->base.mtime_nsec);
return TF_SetStatus(status, TF_OK, "");
}
@ -545,9 +569,10 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
if (TF_GetCode(status) != TF_OK) return -1;
if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
path, stat.generation_number)) {
std::cout
<< "File signature has been changed. Refreshing the cache. Path: "
<< path;
TF_VLog(
1,
"File signature has been changed. Refreshing the cache. Path: %s",
path.c_str());
}
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
} else {
@ -579,6 +604,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
(gcs_file->compose ? 0 : -1)});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
TF_VLog(3, "GcsWritableFile: %s", path);
TF_SetStatus(status, TF_OK, "");
}
@ -624,7 +650,8 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
return;
}
}
TF_VLog(3, "GcsWritableFile: %s with existing file %s", path,
temp_file_name.c_str());
TF_SetStatus(status, TF_OK, "");
}
@ -812,6 +839,10 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string dir = path;
MaybeAppendSlash(&dir);
TF_VLog(3,
"CreateDir: creating directory with path: %s and "
"path_with_slash: %s",
path, dir.c_str());
std::string bucket, object;
ParseGCSPath(dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
@ -826,8 +857,11 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
}
PathExists(filesystem, dir.c_str(), status);
if (TF_GetCode(status) == TF_OK)
if (TF_GetCode(status) == TF_OK) {
// Use the original name for a correct error here.
TF_VLog(3, "CreateDir: directory already exists, not uploading %s", path);
return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
}
auto metadata = gcs_file->gcs_client.InsertObject(
bucket, object, "",
@ -933,6 +967,7 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
static void RenameObject(const TF_Filesystem* filesystem,
const std::string& src, const std::string& dst,
TF_Status* status) {
TF_VLog(3, "RenameObject: started %s to %s", src.c_str(), dst.c_str());
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
@ -946,6 +981,7 @@ static void RenameObject(const TF_Filesystem* filesystem,
bucket_src, object_src, bucket_dst, object_dst);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK) return;
TF_VLog(3, "RenameObject: finished %s to %s", src.c_str(), dst.c_str());
ClearFileCaches(gcs_file, dst);
DeleteFile(filesystem, src.c_str(), status);

View File

@ -43,8 +43,8 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = default;
// This method returns the "Call" Op used to execute the function.
virtual Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) = 0;
virtual Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const = 0;
virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
};

View File

@ -28,6 +28,26 @@ cc_library(
],
)
cc_library(
name = "flat_tensor_function",
srcs = [
"flat_tensor_function.cc",
],
hdrs = [
"flat_tensor_function.h",
],
deps = [
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "variable",
srcs = [
@ -68,7 +88,7 @@ cc_library(
"tf_concrete_function.h",
],
deps = [
":tensorhandle_convertible",
":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
@ -81,3 +101,26 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_signature_def_function",
srcs = [
"tf_signature_def_function.cc",
],
hdrs = [
"tf_signature_def_function.h",
],
deps = [
":flat_tensor_function",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include <memory>
#include <string>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
FlatTensorFunction::FlatTensorFunction(
const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx)
: name_(name), captures_(std::move(captures)), ctx_(ctx) {}
FlatTensorFunction::~FlatTensorFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status FlatTensorFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
out->reset(new FlatTensorFunction(function_def->signature().name(),
std::move(captures), ctx));
return Status();
}
Status FlatTensorFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// FlatTensorFunction models a TF2 eager runtime view of a callable function,
// taking and returning flat lists of tensors, including any captures.
// Effectively, it is a thin wrapper around a FunctionDef owned by the
// EagerContext, and any TensorHandle captures associated with the function. The
// MakeCallOp method automatically marshals the captures after the
// user-provided inputs.
// Note(bmzhao): This class is mainly intended to house low-level reusable
// function logic between SignatureDefFunction and ConcreteFunction, which
// present higher level interfaces. This type does *not* hold any "function
// metadata".
class FlatTensorFunction {
public:
// Factory for creating a FlatTensorFunction.
//
// Params:
// function_def - The function_def associated with the created
// FlatTensorFunction. FlatTensorFunction will register this
// function_def with `ctx` on creation, and de-register it on
// destruction. function_def must be non-null, but
// otherwise has no lifetime requirements.
// captures - The captured TensorHandles associated with this
// FlatTensorFunction.
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
// outlive the created FlatTensorFunction.
// out - The output FlatTensorFunction.
static Status Create(const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx,
std::unique_ptr<FlatTensorFunction>* out);
// This method creates a "Call" Op used to execute the function.
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const;
~FlatTensorFunction();
private:
FlatTensorFunction(const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
ImmediateExecutionContext* ctx);
FlatTensorFunction(const FlatTensorFunction&) = delete;
FlatTensorFunction& operator=(const FlatTensorFunction&) = delete;
// Name of the FunctionDef corresponding to this FlatTensorFunction
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
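A minimal caller sketch for the new class (illustrative only, not part of this change; `fdef`, `eager_ctx`, `capture_handles`, `args`, and `retvals` are hypothetical names, and the `Execute` call assumes the AbstractOperation interface used elsewhere in this commit):
std::unique_ptr<FlatTensorFunction> func;
TF_RETURN_IF_ERROR(
    FlatTensorFunction::Create(fdef, capture_handles, eager_ctx, &func));
ImmediateOpPtr call_op;
TF_RETURN_IF_ERROR(func->MakeCallOp(args, &call_op));
// MakeCallOp has already appended the captures after `args`; just execute.
int num_retvals = static_cast<int>(retvals.size());
TF_RETURN_IF_ERROR(call_op->Execute(absl::MakeSpan(retvals), &num_retvals));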

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
@ -33,32 +33,20 @@ limitations under the License.
namespace tensorflow {
TFConcreteFunction::TFConcreteFunction(
const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx)
: name_(name),
captures_(std::move(captures)),
metadata_(std::move(metadata)),
ctx_(ctx) {}
TFConcreteFunction::~TFConcreteFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
TFConcreteFunction::TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
FunctionMetadata metadata)
: func_(std::move(func)), metadata_(std::move(metadata)) {}
Status TFConcreteFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx,
std::unique_ptr<TFConcreteFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
out->reset(new TFConcreteFunction(function_def->signature().name(),
std::move(captures), std::move(metadata),
ctx));
std::unique_ptr<FlatTensorFunction> func;
TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
function_def, std::move(captures), ctx, &func));
out->reset(new TFConcreteFunction(std::move(func), std::move(metadata)));
return Status();
}
@ -66,30 +54,9 @@ const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
return metadata_;
}
Status TFConcreteFunction::GetCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
// Adding the user-provided inputs to the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(inputs));
absl::Span<AbstractTensorHandle* const> captures(
reinterpret_cast<AbstractTensorHandle**>(captures_.data()),
captures_.size());
// Adding the captures of the function.
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));
return Status();
Status TFConcreteFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
return func_->MakeCallOp(inputs, out);
}
} // namespace tensorflow

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
@ -58,26 +58,22 @@ class TFConcreteFunction : public ConcreteFunction {
std::unique_ptr<TFConcreteFunction>* out);
// This method returns the "Call" Op used to execute the function.
Status GetCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) override;
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const override;
const FunctionMetadata& GetFunctionMetadata() const override;
~TFConcreteFunction() override;
~TFConcreteFunction() override = default;
private:
TFConcreteFunction(const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx);
TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func,
FunctionMetadata metadata);
TFConcreteFunction(const TFConcreteFunction&) = delete;
TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
// Name of the FunctionDef corresponding to this TFConcreteFunction
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
std::unique_ptr<FlatTensorFunction> func_;
FunctionMetadata metadata_;
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include <memory>
#include <string>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
TFSignatureDefFunction::TFSignatureDefFunction(
std::unique_ptr<FlatTensorFunction> func,
SignatureDefFunctionMetadata metadata)
: func_(std::move(func)), metadata_(std::move(metadata)) {}
Status TFSignatureDefFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx,
std::unique_ptr<TFSignatureDefFunction>* out) {
std::unique_ptr<FlatTensorFunction> func;
TF_RETURN_IF_ERROR(FlatTensorFunction::Create(
function_def, std::move(captures), ctx, &func));
out->reset(new TFSignatureDefFunction(std::move(func), std::move(metadata)));
return Status();
}
const SignatureDefFunctionMetadata&
TFSignatureDefFunction::GetFunctionMetadata() const {
return metadata_;
}
Status TFSignatureDefFunction::MakeCallOp(
absl::Span<AbstractTensorHandle* const> inputs, ImmediateOpPtr* out) const {
return func_->MakeCallOp(inputs, out);
}
} // namespace tensorflow

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// This is the TF eager runtime implementation of SignatureDefFunction (separate
// from the TFRT implementation). The user-facing API of SignatureDefFunctions
// and their semantic differences from ConcreteFunction are described here:
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59
// Additional implementation notes are available here:
// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48
class TFSignatureDefFunction : public SignatureDefFunction {
public:
// Factory function for creating a TFSignatureDefFunction.
//
// Params:
// function_def - The function_def associated with the created
// TFSignatureDefFunction. TFSignatureDefFunction will
// register this function_def with `ctx` on creation, and
// de-register it on destruction. function_def must be
// non-null, but otherwise has no lifetime requirements.
// captures - The captured TensorHandles associated with this
// TFSignatureDefFunction.
// metadata - FunctionMetadata associated with this TFSignatureDefFunction.
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
// outlive TFSignatureDefFunction.
// out - The output TFSignatureDefFunction.
static Status Create(const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
SignatureDefFunctionMetadata metadata,
ImmediateExecutionContext* ctx,
std::unique_ptr<TFSignatureDefFunction>* out);
// This method creates a "Call" Op used to execute the function.
Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
ImmediateOpPtr* out) const override;
const SignatureDefFunctionMetadata& GetFunctionMetadata() const override;
~TFSignatureDefFunction() override = default;
private:
TFSignatureDefFunction(std::unique_ptr<FlatTensorFunction> func,
SignatureDefFunctionMetadata metadata);
TFSignatureDefFunction(const TFSignatureDefFunction&) = delete;
TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete;
std::unique_ptr<FlatTensorFunction> func_;
SignatureDefFunctionMetadata metadata_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SIGNATURE_DEF_FUNCTION_H_

View File

@ -34,15 +34,15 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
&tensorflow::unwrap(func)->GetFunctionMetadata()));
}
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
TFE_Op* TF_ConcreteFunctionMakeCallOp(TF_ConcreteFunction* func,
TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
tensorflow::ImmediateOpPtr call_op;
absl::Span<tensorflow::AbstractTensorHandle* const> input_span(
reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(inputs)),
static_cast<size_t>(num_inputs));
status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op);
status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op);
if (!status->status.ok()) {
return nullptr;
}

View File

@ -107,7 +107,7 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
compute_fn_inputs.push_back(input_a);
compute_fn_inputs.push_back(input_b);
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(
TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -47,7 +47,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
// high-level API here. A strawman for what this interface could look like:
// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value*
// inputs, int num_inputs, TF_Status* status);
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp(
TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status);

View File

@ -132,6 +132,23 @@ tf_cc_test(
],
)
tf_cc_test(
name = "summary_op_benchmark_test",
size = "small",
srcs = ["summary_op_benchmark_test.cc"],
deps = [
":summary_op",
"//tensorflow/c:kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "tensor_shape_utils",
srcs = ["tensor_shape_utils.cc"],

View File

@ -93,11 +93,13 @@ void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
std::ostringstream err;
err << "Nan in summary histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, status.get());
return;
} else if (Eigen::numext::isinf(double_val)) {
std::ostringstream err;
err << "Infinity in Histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
histo.Add(double_val);

View File

@ -0,0 +1,71 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) {
Graph* g = new Graph(OpRegistry::Global());
Tensor tags(DT_STRING, shape);
Tensor values(DT_FLOAT, shape);
for (int i = 0; i < tags.NumElements(); ++i) {
tags.flat<tstring>()(i) = tag;
values.flat<float>()(i) = value;
}
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary")
.Input(test::graph::Constant(g, tags))
.Input(test::graph::Constant(g, values))
.Attr("T", DT_FLOAT)
.Finalize(g, &ret));
return g;
}
// Macro used to construct an initializer list for TensorShape
#define DIMARGS(...) \
{ __VA_ARGS__ }
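// For example, `DIMARGS (5, 10, 100)` expands to `{ 5, 10, 100 }`, so the
// `TensorShape tensorshape(DIMARGS dims)` line in the benchmark macro below
// receives a braced initializer list.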
// Random parameters for testing
constexpr char longTagParam[] = "LONGTAG____________________________";
constexpr float largeValueParam = 2352352.2623433;
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
void BM_ScalarSummary##name##device(int iters) { \
testing::StopTiming(); \
TensorShape tensorshape(DIMARGS dims); \
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
BENCHMARK(BM_ScalarSummary##name##device);
BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);
// Benchmark for large shapes
BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2);
// Benchmark for large tag tstring
BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2);
// Benchmark for large values
BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam);
} // namespace
} // namespace tensorflow

View File

@ -329,6 +329,7 @@ cc_library(
srcs = ["xla_compilation_cache.cc"],
hdrs = ["xla_compilation_cache.h"],
deps = [
":flags",
":xla_activity_listener",
":xla_activity_proto_cc",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
@ -361,8 +362,11 @@ tf_cc_test(
"xla_compilation_cache_test.cc",
],
deps = [
":flags",
":xla_compilation_cache",
":xla_cpu_jit",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
@ -918,6 +922,7 @@ tf_cc_test(
":xla_cpu_jit",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/compiler/tf2xla:test_util",

View File

@ -518,10 +518,15 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
}
}
// Returns `true` iff the node has the given `attr` set to `true`. Returns
// `false` both when the attr is missing and when it is set to `false`.
static bool HasBoolAttr(const NodeDef& node, const char* attr) {
const auto& it = node.attr().find(attr);
return it != node.attr().end() && it->second.b();
}
bool CanCreateXlaKernel(const NodeDef& node_def) {
// If kXlaMustCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
return it != node_def.attr().end() && it->second.b();
return HasBoolAttr(node_def, kXlaMustCompileAttr);
}
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
@ -564,4 +569,58 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
return Status::OK();
}
static auto const ops_triggering_xla_compilation =
new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
"XlaConv",
"XlaDequantize",
"XlaDot",
"XlaDynamicSlice",
"XlaDynamicUpdateSlice",
"XlaEinsum",
"XlaGather",
"XlaIf",
"XlaKeyValueSort",
"XlaPad",
"XlaRecv",
"XlaReduce",
"XlaReduceWindow",
"XlaReplicaId",
"XlaScatter",
"XlaSelectAndScatter",
"XlaSelfAdjointEig",
"XlaSend",
"XlaSharding",
"XlaSort",
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaWhile"};
static bool NodeCanTriggerXlaCompilation(const NodeDef& node) {
return node.attr().find(kXlaClusterIdAttr) != node.attr().end() ||
HasBoolAttr(node, kXlaMustCompileAttr) ||
HasBoolAttr(node, kXlaCompileAttr) ||
HasBoolAttr(node, kXlaScopeAttr) ||
HasBoolAttr(node, kXlaInternalScopeAttr) ||
ops_triggering_xla_compilation->count(node.op());
}
bool CanTriggerXlaCompilation(const GraphDef& graph) {
for (const FunctionDef& function : graph.library().function()) {
for (const NodeDef& node : function.node_def()) {
if (NodeCanTriggerXlaCompilation(node)) {
return true;
}
}
}
for (const NodeDef& node : graph.node()) {
if (NodeCanTriggerXlaCompilation(node)) {
return true;
}
}
return false;
}
} // namespace tensorflow

View File

@ -126,9 +126,10 @@ class RecursiveCompilabilityChecker {
bool allow_inaccurate_ops = false;
};
RecursiveCompilabilityChecker(const OperationFilter* op_filter,
const DeviceType* jit_device_type)
: op_filter_(*op_filter), jit_device_type_(*jit_device_type) {}
RecursiveCompilabilityChecker(OperationFilter op_filter,
DeviceType jit_device_type)
: op_filter_(std::move(op_filter)),
jit_device_type_(std::move(jit_device_type)) {}
using UncompilableNodesMap =
std::map<std::string,
@ -259,8 +260,8 @@ class RecursiveCompilabilityChecker {
// Make sure we don't recurse infinitely on recursive functions.
const size_t kMaxRecursionDepth = 10;
const OperationFilter& op_filter_;
const DeviceType& jit_device_type_;
const OperationFilter op_filter_;
const DeviceType jit_device_type_;
};
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
@ -282,6 +283,9 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);
// Checks whether the graph can trigger XLA compilation.
bool CanTriggerXlaCompilation(const GraphDef& graph);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -75,8 +76,8 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
op_filter_.allow_inaccurate_ops = false;
op_filter_.allow_slow_ops = false;
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(&op_filter_,
&device_type_);
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
device_type_);
}
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
@ -354,5 +355,110 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
"unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary library;
FunctionDef identity_func = FunctionDefHelper::Create(
"IdentityFunc",
/*in_def=*/{"x:float"},
/*out_def=*/{"res:float"},
/*attr_def=*/{},
/*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
/*ret_def*/ {{"res", "t0:output"}});
*library.add_function() = identity_func;
Output in = ops::Placeholder(root, DT_FLOAT);
NameAttrList b_name_attr;
b_name_attr.set_name("IdentityFunc");
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
b_name_attr);
GraphDef graph_def;
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
EXPECT_FALSE(CanTriggerXlaCompilation(graph_def));
}
TEST_F(CompilabilityCheckUtilTest, TestXlaOpsCanTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary library;
FunctionDef sort_func = FunctionDefHelper::Create(
"SortFunc",
/*in_def=*/{"x:float"},
/*out_def=*/{"res:float"},
/*attr_def=*/{},
/*node_def=*/{{{"t0"}, "XlaSort", {"x"}, {{"T", DT_FLOAT}}}},
/*ret_def*/ {{"res", "t0:output"}});
*library.add_function() = sort_func;
Output in = ops::Placeholder(root, DT_FLOAT);
NameAttrList b_name_attr;
b_name_attr.set_name("SortFunc");
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
b_name_attr);
GraphDef graph_def;
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary library;
AttrValue true_attribute;
true_attribute.set_b(true);
FunctionDef identity_func = FunctionDefHelper::Create(
"IdentityFunc",
/*in_def=*/{"x:float"},
/*out_def=*/{"res:float"},
/*attr_def=*/{},
/*node_def=*/{{{"t0"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
/*ret_def*/ {{"res", "t0:output"}});
(*identity_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;
FunctionDef call_identity = FunctionDefHelper::Create(
"CallIdentity",
/*in_def=*/{"x:float"},
/*out_def=*/{"z:float"}, /*attr_def=*/{},
/*node_def=*/
{{{"func_call"},
"PartitionedCall",
{"x"},
{{"Tin", DataTypeSlice({DT_FLOAT})},
{"Tout", DataTypeSlice({DT_FLOAT})},
{"f",
FunctionDefHelper::FunctionRef("IdentityRef", {{"T", DT_FLOAT}})},
{kXlaMustCompileAttr, true}}}},
/*ret_def=*/{{"z", "func_call:output:0"}});
*library.add_function() = identity_func;
*library.add_function() = call_identity;
Output in = ops::Placeholder(root, DT_FLOAT);
NameAttrList b_name_attr;
b_name_attr.set_name("CallIdentity");
ops::PartitionedCall call(root.WithOpName("call"), {in}, {DT_FLOAT},
b_name_attr);
GraphDef graph_def;
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
EXPECT_TRUE(CanTriggerXlaCompilation(graph_def));
}
} // namespace
} // namespace tensorflow

View File

@ -28,4 +28,6 @@ const char* const kXlaScopeAttr = "_XlaScope";
// only when auto_jit is ON.
const char* const kXlaInternalScopeAttr = "_XlaInternalScope";
const char* const kXlaClusterIdAttr = "_xla_compile_id";
} // namespace tensorflow

View File

@ -35,6 +35,9 @@ extern const char* const kXlaCompileAttr; // "_XlaCompile"
extern const char* const kXlaScopeAttr; // "_XlaScope"
extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope"
// The id of the compiled cluster.
extern const char* const kXlaClusterIdAttr; // "_xla_compile_id"
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -34,9 +35,6 @@ limitations under the License.
namespace tensorflow {
const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
"_xla_compile_id";
namespace {
const char* const kXlaClusterOutput = "XlaClusterOutput";
@ -45,10 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) {
for (Node* n : graph->nodes()) {
string name;
// Only consider nodes being compiled.
if (!GetNodeAttr(n->attrs(),
EncapsulateXlaComputationsPass::kXlaClusterAttr, &name)
.ok())
continue;
if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue;
// Early return for any node with a device that is not a CPU or GPU.
DeviceNameUtils::ParsedName parsed;
if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) {
@ -180,8 +175,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
retvals[i]->AddAttr("index", i);
}
AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
call_def);
AddNodeAttr(kXlaClusterIdAttr, call_def->name(), call_def);
AddNodeAttr("_variable_start_index", variable_start_index, call_def);
// Uniquify the function name.
@ -216,8 +210,8 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// O(n) pass over the edges.
for (const Edge* e : (*graph)->edges()) {
if (!e->IsControlEdge() &&
e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
e->src()->attrs().Find(kXlaClusterIdAttr) != nullptr &&
e->dst()->attrs().Find(kXlaClusterIdAttr) == nullptr &&
e->dst()->type_string() != kXlaClusterOutput) {
return errors::InvalidArgument(
"Undeclared output of XLA computation. Some common causes of this "
@ -232,9 +226,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
auto output = absl::make_unique<Graph>((*graph)->op_registry());
TF_RETURN_WITH_CONTEXT_IF_ERROR(
EncapsulateSubgraphsInFunctions(kXlaClusterAttr, **graph, RewriteSubgraph,
/*reuse_existing_functions=*/true,
&output, flib_def),
EncapsulateSubgraphsInFunctions(
kXlaClusterIdAttr, **graph, RewriteSubgraph,
/*reuse_existing_functions=*/true, &output, flib_def),
"EncapsulateXlaComputationsPass failed");
graph->swap(output);
return Status::OK();
@ -246,7 +240,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// while iterating.
std::vector<Node*> launch_nodes;
for (Node* n : graph->nodes()) {
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr);
const string& name = GetNodeAttrString(n->attrs(), kXlaClusterIdAttr);
if (!name.empty()) {
launch_nodes.push_back(n);
}

View File

@ -34,8 +34,6 @@ namespace tensorflow {
// XlaLaunch operators.
class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
public:
static const char* const kXlaClusterAttr; // _xla_compile_id
Status Run(const GraphOptimizationPassOptions& options) override;
// The following methods are public only for unit tests.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
@ -46,19 +47,18 @@ static std::unique_ptr<Graph> MakeOuterGraph(
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
NodeDef def;
TF_CHECK_OK(
NodeDefBuilder("launch0", function, &flib_def)
.Input(a.node()->name(), 0, DT_INT32)
.Input(b.node()->name(), 0, DT_FLOAT)
.Input(c.node()->name(), 0, DT_INT32)
.Input(d.node()->name(), 0, DT_FLOAT)
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
.Device("/gpu:0")
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
TF_CHECK_OK(NodeDefBuilder("launch0", function, &flib_def)
.Input(a.node()->name(), 0, DT_INT32)
.Input(b.node()->name(), 0, DT_FLOAT)
.Input(c.node()->name(), 0, DT_INT32)
.Input(d.node()->name(), 0, DT_FLOAT)
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
.Device("/gpu:0")
.Attr(kXlaClusterIdAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
Status status;
Node* launch = scope.graph()->AddNode(def, &status);
@ -107,7 +107,7 @@ static std::unique_ptr<Graph> MakeBodyGraph() {
auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
node->AddAttr(kXlaClusterIdAttr, "launch0");
node->set_requested_device("/gpu:0");
};
@ -155,8 +155,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
: ops::Add(scope.WithOpName("E"), a1, a0);
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
"launch0");
node->AddAttr(kXlaClusterIdAttr, "launch0");
};
add_attrs(e.node());
@ -216,7 +215,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
node->AddAttr(kXlaClusterIdAttr, "launch0");
node->set_requested_device("/gpu:0");
};

View File

@ -268,4 +268,10 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static std::atomic<bool> xla_compilation_disabled(false);
void DisableXlaCompilation() { xla_compilation_disabled = true; }
bool FailOnXlaCompilation() { return xla_compilation_disabled; }
} // namespace tensorflow

View File

@ -162,6 +162,13 @@ MlirCommonFlags* GetMlirCommonFlags();
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Disables XLA compilation: any compilation attempt will return an error
// instead. Can be used by a server to ensure that JIT compilation is opt-in.
void DisableXlaCompilation();
// Returns `false` unless `DisableXlaCompilation` was called.
bool FailOnXlaCompilation();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
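A sketch of the intended opt-in pattern (the call site is hypothetical; the error text matches the check added to XlaCompilationCache in this change):
// In a serving binary's main(), before any graphs are executed:
tensorflow::DisableXlaCompilation();
// Any later JIT compilation attempt now fails with
// errors::Internal("XLA compilation disabled") instead of compiling.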

View File

@ -158,7 +158,7 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
constants_(constants),
resources_(resources),
function_(function),
platform_info_(XlaPlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
has_ref_vars_(has_ref_vars) {}
static Status CompileToLocalExecutable(
@ -180,7 +180,7 @@ static Status CompileToLocalExecutable(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
return BuildXlaCompilationCache(ctx, platform_info, cache);
return BuildXlaCompilationCache(ctx->device(), platform_info, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
@ -191,7 +191,9 @@ static Status CompileToLocalExecutable(
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options = GenerateCompilerOptions(
*cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
*cache, *ctx->function_library(), ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info, has_ref_vars, &tf_allocator_adapter);
std::map<int, Tensor> constant_args;
for (int i : constants) {
@ -248,8 +250,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
int device_ordinal = stream ? stream->parent()->device_ordinal()
: client->default_device_ordinal();
XlaComputationLaunchContext launch_context(
@ -373,7 +377,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
constants_(ConstantsVector(ctx)),
resources_(ResourcesVector(ctx)),
function_(FunctionAttr(ctx)),
platform_info_(XlaPlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
must_compile_(MustCompileAttr(ctx)),
has_ref_vars_(HasRefVars(ctx)) {}
@ -461,7 +465,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
}
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
void XlaRunOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaRunOp " << def().name();
@ -472,8 +476,10 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
XlaExecutableClosureStore::Global()->Consume(key);
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
int device_ordinal = stream ? stream->parent()->device_ordinal()

View File

@ -1196,12 +1196,9 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
continue;
}
DeviceType jit_device_type(registration->compilation_device_name);
RecursiveCompilabilityChecker::OperationFilter op_filter =
CreateOperationFilter(*registration);
if (!RecursiveCompilabilityChecker{&op_filter, &jit_device_type}
if (!RecursiveCompilabilityChecker{
CreateOperationFilter(*registration),
DeviceType{registration->compilation_device_name}}
.IsCompilableNode(*node, lib_runtime)) {
continue;
}
@ -1718,7 +1715,6 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
&registration));
DeviceType jit_device_type(registration->compilation_device_name);
// We can always *compile* resource operations, stateful RNGs and dummy ops,
// even if we are sometimes unable to auto-cluster them.
@ -1733,7 +1729,8 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
op_filter.allow_slow_ops = true;
op_filter.allow_inaccurate_ops = true;
RecursiveCompilabilityChecker checker{&op_filter, &jit_device_type};
RecursiveCompilabilityChecker checker{
op_filter, DeviceType{registration->compilation_device_name}};
if (!uncompilable_node_info) {
// We do not need uncompilable node info. Just return the result.
return checker.IsCompilableCall(ndef, flr);

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
@ -323,6 +324,10 @@ Status XlaCompilationCache::CompileImpl(
absl::optional<int64> compile_threshold,
const XlaCompiler::CompilationResult** out_compilation_result,
xla::LocalExecutable** out_executable) {
if (FailOnXlaCompilation()) {
return errors::Internal("XLA compilation disabled");
}
DCHECK_NE(out_executable, nullptr);
VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@ -52,6 +54,30 @@ TEST(XlaCompilationCacheTest, SignatureEquality) {
}
}
TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) {
NameAttrList fn;
fn.set_name("afunction");
DisableXlaCompilation();
xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
DeviceType device_type = DeviceType(DEVICE_CPU_XLA_JIT);
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
auto cache = new XlaCompilationCache(client, device_type);
core::ScopedUnref cache_ref(cache);
Status status = cache->Compile(XlaCompiler::Options{}, fn, {},
XlaCompiler::CompileOptions{},
XlaCompilationCache::CompileMode::kStrict,
&compilation_result, &executable);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(
absl::StrContains(status.error_message(), "XLA compilation disabled"));
}
static void BM_BuildSignature(int iters, int n_args) {
NameAttrList fn;
fn.set_name("afunction");

View File

@ -49,8 +49,10 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_);
XlaComputationLaunchContext launch_context(
client, allocator, client->default_device_ordinal(),
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
@ -157,13 +159,16 @@ Status XlaCompileOnDemandOp::Compile(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", cache,
[&](XlaCompilationCache** write_into_cache) {
return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
return BuildXlaCompilationCache(ctx->device(), platform_info_,
write_into_cache);
}));
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(**cache, ctx, platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::Options options = GenerateCompilerOptions(
**cache, *ctx->function_library(), ctx->device(),
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;

View File

@ -37,7 +37,8 @@ namespace tensorflow {
class XlaCompileOnDemandOp : public OpKernel {
public:
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
: OpKernel(ctx),
platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}
void Compute(OpKernelContext* ctx) override;
private:

View File

@ -51,7 +51,7 @@ Status XlaCpuDeviceFactory::CreateDevices(
std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}
bool compile_on_demand = flags->tf_xla_compile_on_demand;

View File

@ -94,6 +94,11 @@ class XlaDevice : public LocalDevice {
static Status GetMetadata(OpKernelConstruction* ctx,
const Metadata** metadata);
// Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
// `device`.
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
struct Options {
// The StreamExecutor platform. Not owned. Must be non-null.
se::Platform* platform = nullptr;
@ -196,8 +201,6 @@ class XlaDevice : public LocalDevice {
xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
Status MakeTensorFromProto(XlaDeviceContext* device_context,
const TensorProto& tensor_proto,

View File

@ -66,7 +66,7 @@ class XlaGpuDeviceFactory : public DeviceFactory {
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
if (!flags->tf_xla_enable_xla_devices) {
LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
return Status::OK();
}

View File

@ -44,12 +44,6 @@ namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
const char kPossibleNonVariableResourceHintMessage[] =
"If the error is similar to `Trying to access resource using the wrong "
"type`, this is likely because XLA only accepts Resource Variables as "
"inputs by snapshotting their values. Other TensorFlow resource types like "
"TensorList/TensorArray/Stack are not supported. Try removing non-variable "
"resource inputs to XLA.";
} // anonymous namespace
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
Status BuildXlaCompilationCache(OpKernelContext* ctx,
Status BuildXlaCompilationCache(DeviceBase* device,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
@ -59,7 +59,7 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx,
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
device->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
@ -75,21 +75,21 @@ Status BuildXlaCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) {
auto device = static_cast<Device*>(device_base);
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
if (device->device_type() == DEVICE_CPU) {
platform_id = se::host::kHostPlatformId;
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
platform_id = ctx->device()
->tensorflow_gpu_device_info()
} else if (device->device_type() == DEVICE_GPU) {
platform_id = device->tensorflow_gpu_device_info()
->stream->parent()
->platform()
->id();
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
} else if (XlaDevice::GetMetadataFromDevice(device, &xla_device_metadata)
.ok()) {
// If we are on an XlaDevice, use the underlying XLA platform's allocator
// directly. We could use the StreamExecutor's allocator which may
// theoretically be more correct, but XLA returns a nice OOM message in a
@ -104,47 +104,46 @@ XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
xla_device_metadata->client()->backend().memory_allocator();
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
custom_allocator);
return XlaPlatformInfo(DeviceType(device->device_type()), platform_id,
xla_device_metadata, custom_allocator);
}
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
DeviceBase* device, se::Stream* stream,
const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
if (!stream) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
tf_allocator_adapter->emplace(device->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
}
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
ctx->op_device_context()->stream());
tf_allocator_adapter->emplace(device->GetAllocator({}), stream);
return &tf_allocator_adapter->value();
}
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
CHECK(ctx->function_library());
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache.client());
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
if (stream != nullptr) {
options.device_ordinal = stream->parent()->device_ordinal();
}
options.device_type = cache.device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.flib_def = function_library.GetFunctionLibraryDefinition();
options.graph_def_version = function_library.graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(tf_allocator_adapter, ctx, platform_info);
GetAllocator(tf_allocator_adapter, device, stream, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();

View File

@ -80,27 +80,31 @@ class XlaPlatformInfo {
};
// Returns created XLA compilation cache.
Status BuildXlaCompilationCache(OpKernelContext* ctx,
Status BuildXlaCompilationCache(DeviceBase* dev,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache);
// Returns information about the platform from kernel context.
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
// Returns the allocator from platform info if non-null; otherwise populates
// `tf_allocator_adapter` with the allocator from the given device and returns
// a pointer to it.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
//
// `stream` parameter is nullable when running on host.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
DeviceBase* device, se::Stream* stream,
const XlaPlatformInfo& platform_info);
// Returns created options for the XLA compiler, and writes the used allocator
// into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
const XlaCompilationCache& cache,
const FunctionLibraryRuntime& function_library, DeviceBase* device,
se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
} // namespace tensorflow
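A minimal call-site sketch of the refactored helpers above, assuming the caller is still a kernel with an OpKernelContext* `ctx`, an XlaCompilationCache* `cache`, and a `has_ref_vars` flag from its surrounding logic; this is illustrative only and not part of the change:

  // Derive the device/stream pair once and feed it to the new helpers.
  DeviceBase* device = ctx->device();
  se::Stream* stream = ctx->op_device_context()
                           ? ctx->op_device_context()->stream()
                           : nullptr;
  XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(device);
  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  XlaCompiler::Options options = GenerateCompilerOptions(
      *cache, *ctx->function_library(), device, stream, platform_info,
      has_ref_vars, &tf_allocator_adapter);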


@ -24,11 +24,40 @@ filegroup(
srcs = glob(["**/*.td"]),
)
cc_library(
name = "string_container_utils",
hdrs = ["utils/string_container_utils.h"],
deps = [
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
],
)
cc_library(
name = "array_container_utils",
hdrs = ["utils/array_container_utils.h"],
deps = [
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
],
)
cc_library(
name = "name_utils",
srcs = ["utils/name_utils.cc"],
hdrs = ["utils/name_utils.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "op_or_arg_name_mapper",
srcs = ["op_or_arg_name_mapper.cc"],
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
":name_utils",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",


@ -341,6 +341,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
@ -348,6 +349,22 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "mhlo_control_flow_to_scf",
srcs = ["lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [
":hlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
cc_library(
name = "map_lmhlo_to_scalar_op",
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"],
@ -800,6 +817,7 @@ cc_library(
":lhlo_legalize_to_affine",
":lhlo_legalize_to_gpu",
":lhlo_legalize_to_parallel_loops",
":mhlo_control_flow_to_scf",
":mhlo_fusion",
":mhlo_to_mhlo_lowering_patterns",
":sink_constants_to_control_flow",


@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
let constructor = "createLegalizeControlFlowPass()";
}
def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
let summary = "Legalize from MHLO control flow to SCF control flow.";
let constructor = "createControlFlowToScfPass()";
}
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
let summary = "Legalizes gathers to a torch index select.";
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";


@ -35,6 +35,9 @@ namespace mhlo {
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
/// Lowers MHLO control flow ops to the SCF dialect.
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();


@ -93,6 +93,7 @@ add_mlir_library(MhloToLhloConversion
add_mlir_library(MhloToStandard
legalize_control_flow.cc
legalize_to_standard.cc
mhlo_control_flow_to_scf.cc
DEPENDS
MLIRhlo_opsIncGen


@ -0,0 +1,199 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#define DEBUG_TYPE "mhlo-control-flow-to-scf"
namespace mlir {
namespace mhlo {
namespace {
/// Convert MHLO While to SCF.
void MatchAndRewrite(WhileOp whileOp);
/// Pass that converts MHLO control flow to SCF.
class ControlFlowToScfPass
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<scf::SCFDialect>();
}
void runOnFunction() override {
getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
}
};
// TODO(jpienaar): Look into reformulating as a pattern.
void MatchAndRewrite(WhileOp whileOp) {
// Handle pattern:
// x = start
// step = ...
// limit = ...
// while (x < limit) { ... x += step; }
// Only handling multi-value while loops at the moment.
auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
if (!tupleOp) return;
auto bodyReturn = whileOp.body()
.front()
.getTerminator()
->getOperand(0)
.getDefiningOp<mhlo::TupleOp>();
// Note: due to the shape restrictions on While, if the operand to While is a
// tuple, then so is the return type of the body. But the verifier isn't
// checking that at the moment, so just bail out here if this doesn't hold.
if (!bodyReturn) return;
Value result = whileOp.cond().front().getTerminator()->getOperand(0);
// TODO(jpienaar): Expand to handle more than simple case with LT compare and
// constant step.
auto cmp = result.getDefiningOp<mhlo::CompareOp>();
if (!cmp || cmp.comparison_direction() != "LT") return;
const int kConstant = -1;
auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
if (matchPattern(val, m_Constant())) return {val, kConstant};
// If it is defined by a tuple, then the tuple has to have been fed in and
// the external value is captured.
if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
int index = gte.index().getSExtValue();
return {tupleOp.getOperand(index), index};
}
return {nullptr, 0};
};
using ValueIndex = std::pair<Value, int>;
ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
ValueIndex max = getValueAndIndex(cmp.rhs());
if (!loopIndVar.first || !max.first) return;
auto add =
bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
if (!add) return;
ValueIndex step = getValueAndIndex(add.rhs());
if (step.second != kConstant || !step.first) return;
// Only handle case where tuple isn't propagated as is for now.
// TODO(jpienaar): Remove this when a tuple is also created inside the loop
// to propagate.
for (auto* use : whileOp.body().front().getArgument(0).getUsers())
if (!isa<GetTupleElementOp>(use)) return;
LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = "
<< max.second << " step = " << step.second << "\n";
llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = "
<< max.first << " step = " << step.first << "\n";);
OpBuilder b(whileOp);
// Inputs to new for loop.
llvm::SmallVector<Value, 4> input;
input.reserve(tupleOp.getNumOperands());
for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
input.push_back(r);
for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
input.push_back(r);
auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
auto getAsIndex = [&](Value val) {
auto loc = whileOp.getLoc();
return b.create<ExtractElementOp>(
loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
};
// SCF for uses the index type, so convert these.
auto forloopIndVar = getAsIndex(loopIndVar.first);
auto forMax = getAsIndex(max.first);
auto forStep = getAsIndex(step.first);
auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
forMax, forStep, input);
// Transfer the body without the block arguments.
forOp.getLoopBody().front().getOperations().splice(
forOp.getLoopBody().front().getOperations().end(),
whileOp.body().front().getOperations());
b.setInsertionPointToStart(&forOp.getLoopBody().front());
auto loopIndVarElType =
loopIndVar.first.getType().cast<ShapedType>().getElementType();
Value indVar = b.create<SplatOp>(
whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
forOp.getInductionVar()));
// Update all block argument users to the SCF For args.
for (auto* use :
llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
// TODO(jpienaar): Expand here too when we allow using the tuple in the
// loop.
auto gte = cast<GetTupleElementOp>(use);
// If this is the loop induction var, refer to the loop induction variable,
// as this operand is not updated.
if (gte.index() == loopIndVar.second) {
use->getResult(0).replaceAllUsesWith(indVar);
use->erase();
continue;
}
int index = gte.index().getSExtValue();
// If after the loop induction variable, then decrement as we don't include
// the loop induction variable in the for iter operands.
if (index > loopIndVar.second) --index;
use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
use->erase();
}
// Create new yield op without induction var update.
SmallVector<Value, 4> newYieldOps;
newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
newYieldOps.push_back(r);
for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
newYieldOps.push_back(r);
// Delete return & tuple op.
forOp.getLoopBody().front().back().erase();
forOp.getLoopBody().front().back().erase();
b.setInsertionPointToEnd(&forOp.getLoopBody().front());
b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
// Recombine output tuple with max value of induction variable.
llvm::SmallVector<Value, 4> loopOut;
loopOut.reserve(forOp.getNumResults() + 1);
for (auto r : forOp.getResults().take_front(loopIndVar.second))
loopOut.push_back(r);
loopOut.push_back(max.first);
for (auto r : forOp.getResults().drop_front(loopIndVar.second))
loopOut.push_back(r);
b.setInsertionPoint(whileOp);
auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
whileOp.replaceAllUsesWith(newRes.getOperation());
whileOp.erase();
}
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
return std::make_unique<ControlFlowToScfPass>();
}
} // namespace mhlo
} // namespace mlir
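A minimal sketch of running the new pass from C++, assuming an mlir::MLIRContext `context` and an mlir::OwningModuleRef `module` are already set up; illustrative only, not part of the change:

  mlir::PassManager pm(&context);
  // Convert mhlo.while loops to scf.for on every function in the module.
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createControlFlowToScfPass());
  if (mlir::failed(pm.run(*module))) {
    // Handle pass pipeline failure here.
  }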


@ -0,0 +1,38 @@
// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s
func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) {
%cst = constant dense<-1> : tensor<i32>
%cst_0 = constant dense<1> : tensor<i32>
%cst_1 = constant dense<0> : tensor<i32>
%cst_2 = constant dense<1000> : tensor<i32>
%0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
%1 = "mhlo.while"(%0) ( {
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"mhlo.return"(%4) : (tensor<i1>) -> ()
}, {
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%3 = mhlo.add %2, %cst_0 : tensor<i32>
%4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%6 = "mhlo.tuple"(%3, %4, %5) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
"mhlo.return"(%6) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> ()
}) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
return %1 : tuple<tensor<i32>, tensor<i32>, tensor<i32>>
}
// CHECK-LABEL: func @lt_loop(
// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor<i32>
// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])


@ -1029,14 +1029,49 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32
// CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
}
func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} :
(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul
// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_0:.*]] = constant unit
// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} :
(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed_a
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_2:.*]] = constant unit
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed
// CHECK-LABEL: matmul_transposed_b
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} :
(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
return %0 : tensor<40x40xf32>
// CHECK-LABEL: matmul_transposed_ab
// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
// CHECK: %[[CST_1:.*]] = constant unit
// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>


@ -66,7 +66,6 @@ namespace TFL {
// The actual LegalizeTF Pass.
namespace {
using xla::Status;
using xla::StatusOr;
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
@ -232,26 +231,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
return success();
}
// The following is effectively:
// def : Pat<
// (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
// ConstBoolAttrTrue:$transpose_b),
// (TFL_FullyConnectedOp:$__0 $a, $b,
// NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
LogicalResult ConvertTFMatMulOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_matmul_op = cast<TF::MatMulOp>(op);
if (tf_matmul_op.transpose_a()) return failure();
if (!tf_matmul_op.transpose_b()) return failure();
auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
RankedTensorType type =
input.getType().dyn_cast_or_null<RankedTensorType>();
if (!type || type.getRank() != 2) return {failure(), nullptr};
auto permute_attr = DenseIntElementsAttr::get(
RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
auto permute = rewriter.create<ConstantOp>(
op->getLoc(), permute_attr.getType(), permute_attr);
llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
type.getShape()[0]};
auto output = rewriter.create<TFL::TransposeOp>(
op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
input, permute);
return {success(), output};
};
// TODO(jpienaar): Remove once handled via dialect conversion.
if (tf_matmul_op.transpose_a()) {
LogicalResult result = success();
std::tie(result, lhs) = transpose(lhs);
if (failed(result)) return failure();
}
if (!tf_matmul_op.transpose_b()) {
LogicalResult result = success();
std::tie(result, rhs) = transpose(rhs);
if (failed(result)) return failure();
}
Type output_type = tf_matmul_op.getResult().getType();
// TODO(jpienaar): Follow up post shuffle discussion.
auto no_input = rewriter.create<ConstantOp>(
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
auto fc_op = rewriter.create<FullyConnectedOp>(
op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0),
op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
rewriter.getBoolAttr(false));
rewriter.replaceOp(op, {fc_op.getResult(0)});
return success();
}
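A rough sketch of how a pattern like ConvertTFMatMulOp is typically registered and applied, assuming the enclosing pass collects patterns into an OwningRewritePatternList over a FuncOp `func` with MLIRContext* `context`; the exact wiring in the LegalizeTF pass may differ:

  mlir::OwningRewritePatternList patterns;
  patterns.insert<ConvertTFMatMulOp>(context);
  // Greedily apply the registered rewrites to the function being legalized.
  (void)applyPatternsAndFoldGreedily(func, patterns);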


@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "tensorflow/compiler/mlir/utils/name_utils.h"
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
return absl::string_view(ref.data(), ref.size());
@ -103,62 +104,16 @@ int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) {
bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }
namespace {
// Derives name from location.
std::string GetNameFromLoc(mlir::Location loc) {
llvm::SmallVector<llvm::StringRef, 8> loc_names;
llvm::SmallVector<mlir::Location, 8> locs;
locs.push_back(loc);
bool names_is_nonempty = false;
while (!locs.empty()) {
mlir::Location curr_loc = locs.pop_back_val();
if (auto name_loc = curr_loc.dyn_cast<mlir::NameLoc>()) {
// Add name in NameLoc. For NameLoc we also account for names due to ops
// in functions where the op's name is first.
auto name = name_loc.getName().strref().split('@').first;
loc_names.push_back(name);
if (!name.empty()) names_is_nonempty = true;
continue;
} else if (auto call_loc = curr_loc.dyn_cast<mlir::CallSiteLoc>()) {
// Add name if CallSiteLoc's callee has a NameLoc (as should be the
// case if imported with DebugInfo).
if (auto name_loc = call_loc.getCallee().dyn_cast<mlir::NameLoc>()) {
auto name = name_loc.getName().strref().split('@').first;
loc_names.push_back(name);
if (!name.empty()) names_is_nonempty = true;
continue;
}
} else if (auto fused_loc = curr_loc.dyn_cast<mlir::FusedLoc>()) {
// Push all locations in FusedLoc in reverse order, so locations are
// visited based on order in FusedLoc.
auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
continue;
}
// Location is not a supported type, so an empty StringRef is added.
loc_names.push_back(llvm::StringRef());
}
if (names_is_nonempty)
return llvm::join(loc_names.begin(), loc_names.end(), ";");
return "";
}
} // anonymous namespace
std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
auto name_from_loc = GetNameFromLoc(op->getLoc());
auto name_from_loc = mlir::GetNameFromLoc(op->getLoc());
if (!name_from_loc.empty()) return name_from_loc;
// If the location is none of the expected types, then simply use name
// generated using the op type.
return std::string(op->getName().getStringRef());
}
auto val = op_or_val.dyn_cast<mlir::Value>();
auto name_from_loc = GetNameFromLoc(val.getLoc());
auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
if (!name_from_loc.empty()) return name_from_loc;
// If the location is none of the expected types, then simply use name
// generated using the op type. Follow TF convention and append the result


@ -794,6 +794,7 @@ cc_library(
"transforms/tpu_identity_pruning.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_outside_compilation_cluster.cc",
"transforms/tpu_resource_read_for_write.cc",
"transforms/tpu_rewrite_pass.cc",
"transforms/tpu_sharding_identification_pass.cc",
"transforms/tpu_space_to_depth_pass.cc",
@ -960,6 +961,7 @@ cc_library(
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/compiler/jit:shape_inference_helpers",
"//tensorflow/compiler/mlir:name_utils",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
"//tensorflow/compiler/xla:status_macros",


@ -2,7 +2,6 @@ load(
"//tensorflow:tensorflow.bzl",
"tf_copts",
"tf_cuda_library",
"tfe_xla_copts",
)
package(
@ -20,7 +19,7 @@ tf_cuda_library(
srcs = [
"c_api_unified_experimental_mlir.cc",
],
copts = tf_copts() + tfe_xla_copts(),
copts = tf_copts(),
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:tensor_interface",


@ -452,7 +452,8 @@ Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) {
return Unimplemented("SetAttrFloat has not been implemented yet.");
}
Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) {
return Unimplemented("SetAttrBool has not been implemented yet.");
attrs_[attr_name] = BoolAttr::get(value, context_);
return Status::OK();
}
Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) {


@ -250,33 +250,6 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
// tf_executor.fetch
//===----------------------------------------------------------------------===//
namespace {
void Print(FetchOp fetch, OpAsmPrinter &p) {
p << fetch.getOperationName();
if (fetch.getNumOperands() > 0) {
p << ' ';
p.printOperands(fetch.operand_begin(), fetch.operand_end());
p << " : ";
interleaveComma(fetch.getOperandTypes(), p);
}
p.printOptionalAttrDict(fetch.getAttrs());
}
ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc = parser.getCurrentLocation();
return failure(parser.parseOperandList(opInfo) ||
(!opInfo.empty() && parser.parseColonTypeList(types)) ||
parser.resolveOperands(opInfo, types, loc, result.operands) ||
parser.parseOptionalAttrDict(result.attributes)
);
}
} // anonymous namespace
//===----------------------------------------------------------------------===//
// tf_executor.island
//===----------------------------------------------------------------------===//
@ -411,31 +384,6 @@ ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) {
// tf_executor.yield
//===----------------------------------------------------------------------===//
namespace {
void Print(YieldOp yield, OpAsmPrinter &p) {
p << yield.getOperationName();
if (yield.getNumOperands() > 0) {
p << ' ';
p.printOperands(yield.operand_begin(), yield.operand_end());
p << " : ";
interleaveComma(yield.getOperandTypes(), p);
}
p.printOptionalAttrDict(yield.getAttrs());
}
ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> op_info;
SmallVector<Type, 2> types;
llvm::SMLoc loc = parser.getCurrentLocation();
return failure(parser.parseOperandList(op_info) ||
(!op_info.empty() && parser.parseColonTypeList(types)) ||
parser.resolveOperands(op_info, types, loc, result.operands) ||
parser.parseOptionalAttrDict(result.attributes));
}
} // anonymous namespace
//===----------------------------------------------------------------------===//
// tf_executor.Switch
//===----------------------------------------------------------------------===//
@ -848,23 +796,6 @@ LogicalResult Verify(NextIterationSourceOp source) {
return success();
}
void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) {
p << next_iteration.getOperationName() << " : " << next_iteration.getType(0);
p.printOptionalAttrDict(next_iteration.getAttrs());
}
ParseResult ParseNextIterationSourceOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<Type, 1> types;
if (parser.parseColonTypeList(types)) return failure();
MLIRContext *context = parser.getBuilder().getContext();
Type token_type = TokenType::get(context);
Type control_type = ControlType::get(context);
result.addTypes({types.front(), token_type, control_type});
return parser.parseOptionalAttrDict(result.attributes);
}
} // anonymous namespace
//===----------------------------------------------------------------------===//
@ -891,36 +822,6 @@ LogicalResult Verify(NextIterationSinkOp sink) {
return success();
}
void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) {
p << next_iteration.getOperationName() << " [";
p.printOperand(next_iteration.getOperand(0));
p << "] ";
p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
p << " : " << next_iteration.getOperand(1).getType();
p.printOptionalAttrDict(next_iteration.getAttrs());
}
ParseResult ParseNextIterationSinkOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> op_infos;
llvm::SMLoc loc = parser.getCurrentLocation();
// First type is always the token consumed from the NextIteration.source
Type token_type = TokenType::get(parser.getBuilder().getContext());
SmallVector<Type, 1> types = {token_type};
if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) ||
parser.parseOperandList(op_infos) || parser.parseColonTypeList(types))
return failure();
Type control_type = ControlType::get(parser.getBuilder().getContext());
types.append(op_infos.size() - 2, control_type);
if (parser.resolveOperands(op_infos, types, loc, result.operands))
return failure();
return parser.parseOptionalAttrDict(result.attributes);
}
} // anonymous namespace
//===----------------------------------------------------------------------===//
@ -959,32 +860,6 @@ ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) {
// tf_executor.ControlTrigger
//===----------------------------------------------------------------------===//
namespace {
void Print(ControlTriggerOp trigger, OpAsmPrinter &p) {
p << trigger.getOperationName() << ' ';
p.printOperands(trigger.getOperands());
p.printOptionalAttrDict(trigger.getAttrs());
}
ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> op_infos;
SmallVector<Type, 1> types;
llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseOperandList(op_infos)) return failure();
Type control_type = ControlType::get(parser.getBuilder().getContext());
types.append(op_infos.size(), control_type);
if (parser.resolveOperands(op_infos, types, loc, result.operands))
return failure();
// Single control as the only output
result.types.push_back(control_type);
return parser.parseOptionalAttrDict(result.attributes);
}
} // anonymous namespace
//===----------------------------------------------------------------------===//
// tf_executor.LoopCond
//===----------------------------------------------------------------------===//


@ -47,10 +47,12 @@ def TfExecutor_Dialect : Dialect {
}
// Control type.
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">;
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">,
BuildableType<"$_builder.getType<ControlType>()">;
// Token type.
def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">;
def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">,
BuildableType<"$_builder.getType<TokenType>()">;
// TODO(hinsu): Define and use TensorType instead of AnyType for data operands
// and results. For example, MergeOp output type.
@ -148,7 +150,11 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch",
}]>
];
let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";
let verifier = ?;
let printer = ?;
let parser = ?;
}
def TfExecutor_IslandOp : TfExecutor_Op<"island",
@ -229,7 +235,11 @@ def TfExecutor_YieldOp : TfExecutor_Op<"yield",
}]>
];
let assemblyFormat = "($fetches^ `:` type($fetches))? attr-dict";
let verifier = ?;
let printer = ?;
let parser = ?;
}
def TfExecutor_SwitchOp : TfExecutor_Op<"Switch",
@ -466,6 +476,10 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source",
}
}];
let assemblyFormat = "`:` type($output) attr-dict";
let printer = ?;
let parser = ?;
}
@ -527,6 +541,11 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink",
result.attributes.append(attributes.begin(), attributes.end());
}]>
];
let assemblyFormat = " `[` $token `]` $input (`,` $controlInputs^)? `:` type($input) attr-dict";
let printer = ?;
let parser = ?;
}
def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
@ -552,7 +571,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit",
.Attr("T: type")
For example:
%1:2 = tf_executor.Exit %0#0 {T: "tfdtype$DT_INT32"} : tensor<*xi32>
%1:2 = tf_executor.Exit %0#0 : tensor<*xi32> {T: "tfdtype$DT_INT32"}
Note: Additional result corresponds to the control output.
}];
@ -607,6 +626,11 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger",
result.attributes.append(attributes.begin(), attributes.end());
}]>
];
let assemblyFormat = "$controlInputs attr-dict";
let printer = ?;
let parser = ?;
}
def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond",


@ -52,6 +52,12 @@ an output element, this operation computes \\(y = |x|\\).
def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes acos of x element-wise.";
let description = [{
Provided an input tensor, the `tf.math.acos` operation returns the inverse cosine of each element of the tensor. If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`.
Input range is `[-1, 1]` and the output has a range of `[0, pi]`.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
);
@ -94,6 +100,10 @@ def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutA
let description = [{
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
Given two input tensors, the `tf.add` operation computes the sum for every element in the tensor.
Both input and output have a range `(-inf, inf)`.
}];
let arguments = (ins
@ -136,31 +146,6 @@ Inputs must be of same size and shape.
let hasFolder = 1;
}
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
let description = [{
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$y
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> {
let summary = "Adjust the contrast of one or more images.";
@ -1740,6 +1725,24 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
def TF_ConfigureDistributedTPUOp : TF_Op<"ConfigureDistributedTPU", []> {
let summary = [{
Sets up the centralized structures for a distributed TPU system.
}];
let arguments = (ins
StrAttr:$embedding_config,
StrAttr:$tpu_embedding_config,
DefaultValuedAttr<BoolAttr, "false">:$is_global_init,
DefaultValuedAttr<BoolAttr, "false">:$enable_whole_mesh_compilations,
DefaultValuedAttr<BoolAttr, "true">:$compilation_failure_closes_chips
);
let results = (outs
TF_StrTensor:$topology
);
}
def TF_ConjOp : TF_Op<"Conj", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns the complex conjugate of a complex number.";
@ -2786,27 +2789,6 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, TF_SameOpe
let hasFolder = 1;
}
def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns 0 if the denominator is zero.";
let description = [{
*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x,
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [NoSideEffect, SameVariadicOperandSize]> {
let summary = [{
Interleave the values from the `data` tensors into a single tensor.
@ -3853,6 +3835,95 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
}
def TF_FusedBatchNormV2Op : TF_Op<"FusedBatchNormV2", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32]>:$x,
F32Tensor:$scale,
F32Tensor:$offset,
F32Tensor:$mean,
F32Tensor:$variance,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<BoolAttr, "true">:$is_training
);
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32]>:$x,
F32Tensor:$scale,
F32Tensor:$offset,
F32Tensor:$mean,
F32Tensor:$variance,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<BoolAttr, "true">:$is_training
);
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2,
F32Tensor:$reserve_space_3
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
let summary = "Gather slices from `params` according to `indices`.";
@ -6213,14 +6284,14 @@ retained with length 1.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$reduction_indices,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -6324,27 +6395,6 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> {
}];
}
def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";
let description = [{
*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$y
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
let summary = "Computes the mean of elements across dimensions of a tensor.";
@ -6440,14 +6490,14 @@ retained with length 1.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input,
TF_I32OrI64Tensor:$reduction_indices,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@ -7878,33 +7928,6 @@ tf.real(input) ==> [-2.25, 3.25]
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
}
def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x / y element-wise for real types.";
let description = [{
If `x` and `y` are reals, this will return the floating-point division.
*NOTE*: `Div` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$x,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$y
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint8]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the reciprocal of x element-wise.";
@ -9314,6 +9337,18 @@ Generate a sharded filename. The filename is printf formatted as
);
}
def TF_ShutdownDistributedTPUOp : TF_Op<"ShutdownDistributedTPU", []> {
let summary = "Shuts down a running distributed TPU system.";
let description = [{
The op returns an error if no system is running.
}];
let arguments = (ins);
let results = (outs);
}
def TF_SigmoidOp : TF_Op<"Sigmoid", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes sigmoid of `x` element-wise.";
@ -9832,6 +9867,41 @@ backpropagation,
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> {
let summary = [{
Multiply matrix "a" by matrix "b".
}];
let description = [{
The inputs must be two-dimensional matrices and the inner dimension of "a" must
match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not
`SparseTensor`s. This op is optimized for the case where at least one of "a" or
"b" is sparse, in the sense that they have a large proportion of zero values.
The breakeven for using this versus a dense matrix multiply on one platform was
30% zero values in the sparse matrix.
The gradient computation of this operation will only take advantage of sparsity
in the input gradient when that gradient comes from a Relu.
}];
let arguments = (ins
TensorOf<[BF16, F32]>:$a,
TensorOf<[BF16, F32]>:$b,
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
DefaultValuedAttr<BoolAttr, "false">:$a_is_sparse,
DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse
);
let results = (outs
F32Tensor:$product
);
TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>;
}
def TF_SparseReshapeOp : TF_Op<"SparseReshape", [NoSideEffect]> {
let summary = [{
Reshapes a SparseTensor to represent values in a new dense shape.
@ -11625,9 +11695,9 @@ array([[1, 2, 3, 1, 2, 3],
TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
// TODO(parkers): Add folds for multiples = [1,...].
// TODO(parkers): Add errors for negative multiples and multiples.size() !=
// input.rank()
let verifier = [{ return Verify(*this); }];
let hasFolder = 1;
}
def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> {
@ -12893,7 +12963,8 @@ create these operators.
DefaultValuedAttr<I64ArrayAttr, "{1, 1, 1, 1}">:$dilations,
DefaultValuedAttr<BoolAttr, "true">:$use_cudnn_on_gpu,
DefaultValuedAttr<StrArrayAttr, "{}">:$fused_ops,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
DefaultValuedAttr<F32Attr, "0.2f">:$leakyrelu_alpha
);
let results = (outs


@ -157,20 +157,10 @@ class TF_TensorFlowType <string name, string description> :
"TensorFlow " # description # " type">,
BuildableType<"getType<mlir::TF::" # name # "Type>()">;
// Any tensor element type allowed in TensorFlow ops
def TF_ElementType : Type<Or<[AnyFloat.predicate,
AnySignlessInteger.predicate,
AnyUnsignedInteger.predicate,
AnyComplex.predicate,
TF_TFDialectType.predicate]>,
"tf.dtype">;
// Any TensorFlow tensor type
def TF_Tensor : TensorOf<[TF_ElementType]>;
//===----------------------------------------------------------------------===//
// Integer types
// TODO(mgester) shouldn't this be SignedIntOfWidths?
def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
@ -191,10 +181,11 @@ def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
// Any signed integer type
// TODO(mgester) shouldn't this be SignedIntOfWidths?
def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>;
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;
// Any integer tensor types
def TF_IntTensor : TensorOf<[TF_Int]>;
@ -208,8 +199,8 @@ def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">;
def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">;
// Any quantized type
def TF_AnyQuantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
TF_Quint16]>;
def TF_Quantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8,
TF_Quint16], "quantized">;
//===----------------------------------------------------------------------===//
// Floating-point types
@ -217,8 +208,10 @@ def TF_F32Or64 : FloatOfWidths<[32, 64]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;
def TF_Float : AnyTypeOf<[F16, F32, F64, BF16], "floating-point">;
// Any floating-point tensor types
def TF_FpTensor : TensorOf<[AnyFloat]>;
def TF_FpTensor : TensorOf<[TF_Float]>;
//===----------------------------------------------------------------------===//
// Complex types
@ -231,10 +224,9 @@ def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128 : Complex<F<64>>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128],
"64/128-bit complex type">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;
def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>;
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
//===----------------------------------------------------------------------===//
// String/variant/resource types
@ -248,28 +240,113 @@ def TF_VariantTensor : TensorOf<[TF_Variant]>;
def TF_Resource : TF_TensorFlowType<"Resource", "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
//===----------------------------------------------------------------------===//
// Reference types
// Float reference types
def TF_F16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_F32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_F64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
// Any float reference type
def TF_FloatRef : AnyTypeOf<[TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_Bfloat16Ref],
"floating-point reference">;
// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;
// Any complex reference type
def TF_ComplexRef : AnyTypeOf<[TF_Complex64Ref, TF_Complex128Ref], "complex reference">;
// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;
def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;
// Any signed integer reference type
def TF_SIntRef : AnyTypeOf<[TF_Int8Ref, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref],
"signed integer reference">;
// Any unsigned integer reference type
def TF_UIntRef : AnyTypeOf<[TF_Uint8Ref, TF_Uint16Ref, TF_Uint32Ref,
TF_Uint64Ref], "unsigned integer reference">;
// Any integer reference type
def TF_IntRef : AnyTypeOf<[TF_SIntRef, TF_UIntRef], "integer reference">;
// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;
// Any quantized reference type
def TF_QuantizedRef : AnyTypeOf<[TF_Qint8Ref, TF_Qint16Ref, TF_Qint32Ref,
TF_Quint8Ref, TF_Quint16Ref], "quantized reference">;
// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StringRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;
// Reference tensor types
def TF_FpRefTensor : TensorOf<[TF_FloatRef]>;
def TF_I32OrI64RefTensor : TensorOf<[TF_Int32Ref, TF_Int64Ref]>;
//===----------------------------------------------------------------------===//
// Multi-category type constraints
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32Or64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32Or64]>;
// Any integer or floating-point tensor types
def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;
def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex],
"number">;
def TF_Number : AnyTypeOf<[TF_Int, TF_Float, TF_Quantized, TF_Complex],
"number">;
def TF_NumberRef : AnyTypeOf<[TF_IntRef, TF_FloatRef, TF_QuantizedRef,
TF_ComplexRef], "number reference">;
def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;
def TF_NumberTensor : TensorOf<[TF_Number]>;
def TF_NumberRefTensor : TensorOf<[TF_NumberRef]>;
def TF_NumberOrStr : AnyTypeOf<[AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>;
def TF_NumberOrStr : AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8,
TF_Str]>;
def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>;
//===----------------------------------------------------------------------===//
// Tensor and tensor element types
// Bool type
def TF_Bool : I<1>;
// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
def TF_ElementType : Type<Or<[TF_Float.predicate,
TF_Complex.predicate,
TF_Int.predicate,
TF_Bool.predicate,
TF_TFDialectType.predicate]>,
"tf.dtype">;
// Any TensorFlow tensor type
def TF_Tensor : TensorOf<[TF_ElementType]>;
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//


@ -570,36 +570,6 @@ def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]
DerivedAttr shape = TF_DerivedResultShapeAttr;
}
def TF_SparseMatMulOp : TF_Op<"SparseMatMul", [NoSideEffect]> {
let summary = [{
SparseMatMul is MatMul with hints on the sparseness of the matrices.
}];
let description = [{
Similar to MatMul, with a_is_sparse and b_is_sparse indicating whether a and b
are sparse matrices.
}];
let arguments = (ins
TensorOf<[BF16, F32]>:$a,
TensorOf<[BF16, F32]>:$b,
DefaultValuedAttr<BoolAttr, "true">:$a_is_sparse,
DefaultValuedAttr<BoolAttr, "false">:$b_is_sparse,
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b
);
let results = (outs
TensorOf<[F32]>:$product
);
TF_DerivedOperandTypeAttr Ta = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tb = TF_DerivedOperandTypeAttr<1>;
}
def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall",
[CallOpInterface]> {
let summary =
@ -1213,63 +1183,6 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
let verifier = [{ return VerifyPartitionedCall(*this); }];
}
class TF_FusedBatchNormOpBase<string Name> : TF_Op<Name, [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Batch normalization.";
let description = [{
Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
The size of 1D Tensors matches the dimension C of the 4D Tensors.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32]>:$x,
F32Tensor:$scale,
F32Tensor:$offset,
F32Tensor:$mean,
F32Tensor:$variance,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
DefaultValuedAttr<F32Attr, "1.0f">:$exponential_avg_factor,
DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
DefaultValuedAttr<BoolAttr, "true">:$is_training
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
let extraClassDeclaration = [{
// TF_FoldOperandsTransposeInterface:
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2
);
}
def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
let results = (outs
TensorOf<[BF16, F16, F32]>:$y,
F32Tensor:$batch_mean,
F32Tensor:$batch_variance,
F32Tensor:$reserve_space_1,
F32Tensor:$reserve_space_2,
F32Tensor:$reserve_space_3
);
}
def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> {
let summary = [{
Batches all the input tensors to the computation done by the function.
@ -1341,4 +1254,98 @@ must be a Tensor or a list/tuple of Tensors.
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
}
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
let description = [{
*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
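// (Illustration, not part of the generated op documentation: adding a
// tensor<2x3xf32> to a tensor<3xf32> broadcasts the second operand across
// the leading dimension, yielding a tensor<2x3xf32>.)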
let arguments = (ins
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$y
);
let results = (outs
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint8Ref]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns 0 if the denominator is zero.";
let description = [{
*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
let arguments = (ins
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$x,
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$y
);
let results = (outs
TensorOf<[F16, F32, F64, TF_Complex, TF_F16Ref, TF_F32Ref, TF_F64Ref, TF_ComplexRef]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise.";
let description = [{
*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
let arguments = (ins
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$x,
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$y
);
let results = (outs
TensorOf<[TF_Float, I16, I32, I64, TF_Uint8, TF_FloatRef, TF_Int16Ref, TF_Int32Ref, TF_Int64Ref, TF_Uint8Ref]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x / y element-wise for real types.";
let description = [{
If `x` and `y` are reals, this will return the floating-point division.
*NOTE*: `Div` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
}];
let arguments = (ins
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$x,
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$y
);
let results = (outs
TensorOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint16, TF_Uint8, TF_FloatRef, TF_SIntRef, TF_ComplexRef, TF_Uint16Ref, TF_Uint8Ref]>:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
#endif // TF_OPS

View File

@ -1783,6 +1783,87 @@ static LogicalResult Verify(TensorScatterUpdateOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// - input has at least rank 1
// - multiples is rank 1
// - multiples.size() == input.rank()
// - input.rank() == output.rank()
// - Elements in multiples are non-negative
// - input.shape[i] * multiples[i] == output.shape[i]
// for i in [0, input.rank() - 1]
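// As a worked illustration (hypothetical shapes, not taken from a test):
// an input of shape [2, 3] with multiples = [2, 3] must produce an output
// of shape [2*2, 3*3] = [4, 9]; any other static output shape is rejected.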
static LogicalResult Verify(TileOp op) {
auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
auto multiples_type = op.multiples().getType().dyn_cast<RankedTensorType>();
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
if (multiples_type && multiples_type.getRank() != 1) {
return op.emitOpError() << "expected multiples to be rank 1, got rank = "
<< multiples_type.getRank();
}
if (input_type && multiples_type && multiples_type.hasStaticShape() &&
(input_type.getRank() != multiples_type.getNumElements() ||
(input_type.getRank() == 0 && multiples_type.getNumElements() == 1))) {
return op.emitOpError()
<< "expected size of multiples equal to rank of input"
<< ", got multiples of size " << multiples_type.getNumElements()
<< ", and input of rank " << input_type.getRank();
}
if (input_type && output_type) {
if (input_type.getRank() != output_type.getRank()) {
return op.emitOpError()
<< "expected rank of input to equal to rank of output"
<< ", got input of rank " << input_type.getRank()
<< ", and output of rank " << output_type.getRank();
}
DenseIntElementsAttr multiples_attr;
if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) {
for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) {
const int64_t input_dim = input_type.getDimSize(i);
const int64_t output_dim = output_type.getDimSize(i);
const int64_t m = multiples_attr.getValue<APInt>(i).getSExtValue();
if (m < 0) {
return op.emitOpError()
<< "expected multiples to be non-negative, got "
<< "multiples[" << i << "] = " << m;
}
if (!ShapedType::isDynamic(input_dim) &&
!ShapedType::isDynamic(output_dim) && output_dim != input_dim * m) {
return op.emitOpError()
<< "requires input.shape[" << i << "] (" << input_dim << ")"
<< " * " << m << " to be equal to "
<< "output.shape[" << i << "] (" << output_dim << ")";
}
}
}
}
return success();
}
OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
DenseIntElementsAttr multiples_attr;
if (matchPattern(multiples(), m_Constant(&multiples_attr))) {
// Return input directly when multiples are all ones,
// regardless of what the input is.
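// (For example, multiples = dense<[1, 1]>, as exercised by the
// testTileMultiplesAllOnes canonicalization test in this change.)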
if (multiples_attr.isSplat() &&
multiples_attr.getSplatValue<APInt>().getSExtValue() == 1) {
return input();
}
}
return {};
}
//===----------------------------------------------------------------------===//
// TopKV2Op
//===----------------------------------------------------------------------===//

View File

@ -285,7 +285,7 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi
// and certain tf_executor ops are added correctly.
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]]
// CHECK: tf_executor.NextIteration.Sink[{{.*}}] {{.*}}, %[[CONTROL]]
func @next_iteration_sink_control_input() {
tf_executor.graph {
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>

View File

@ -568,6 +568,14 @@ func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2:
return %0: tensor<*xf16>
}
// CHECK-LABEL: testTileMultiplesAllOnes
func @testTileMultiplesAllOnes(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%cst = constant dense <[1, 1]> : tensor<2xi32>
// CHECK: return %arg0
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
return %0: tensor<2x3xf32>
}
// CHECK-LABEL: testLogicalNotOfEqual
func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>

View File

@ -220,7 +220,7 @@ func @merge_islands_only() {
%11:2 = tf_executor.island(%10#1) wraps "tf.opF"() : () -> tensor<i32>
%12:2 = tf_executor.island wraps "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%13 = tf_executor.ControlTrigger %2, %12#1, %9#1
tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32>
tf_executor.NextIteration.Sink[%3#1] %12#0, %13 : tensor<*xi32>
tf_executor.fetch
}
return
@ -244,7 +244,7 @@ func @merge_islands_only() {
// CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]])
// CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32>
// CHECK: %[[CT:.*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3_control]], %[[EXIT_control]]
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]]
// CHECK-NEXT: tf_executor.NextIteration.Sink[%[[NEXTIT_SRC_token]]] %[[ISLAND_3]], %[[CT]]
// Test that no merging took place, as a cycle would be formed otherwise.

View File

@ -7,7 +7,7 @@
# CHECK: %[[NEXTITERATION:[a-z0-9]+]], %[[NEXTITERATION_token:[a-z0-9]+]], {{.*}} = tf_executor.NextIteration.Source
# CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]]
# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION_token]]]
# CHECK: tf_executor.NextIteration.Sink[%[[NEXTITERATION_token]]]
node {
name: "Const"

View File

@ -3468,3 +3468,85 @@ func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}
// -----
func @testTile(%arg0: tensor<2x3x?xf32>) {
%cst = constant dense <[2, 3, 4]> : tensor<3xi32>
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3x?xf32>, tensor<3xi32>) -> tensor<4x9x?xf32>
return
}
// -----
func @testTileMultipleNotRank1(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1xi32>) {
// expected-error @+1 {{expected multiples to be rank 1, got rank = 2}}
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<1x1xi32>) -> tensor<2x3xf32>
return
}
// -----
func @testTileInputRankNotEqualToMultiplesSize(%arg0: tensor<2x3xf32>, %arg1: tensor<3xi32>) {
// expected-error @+1 {{expected size of multiples equal to rank of input, got multiples of size 3, and input of rank 2}}
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<2x3xf32>
return
}
// -----
func @testTileInputRankNotEqualToOutputRank(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) {
// expected-error @+1 {{expected rank of input to equal rank of output, got input of rank 2, and output of rank 3}}
%0 = "tf.Tile"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3x1xf32>
return
}
// -----
func @testTileNegativeMultiples(%arg0: tensor<2x3xf32>) {
%cst = constant dense <[-1, 1]> : tensor<2xi32>
// expected-error @+1 {{expected multiples to be non-negative, got multiples[0] = -1}}
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>
return
}
// -----
func @testTileInvalidOutputShape(%arg0: tensor<2x3xf32>) {
%cst = constant dense <[2, 3]> : tensor<2xi32>
// expected-error @+1 {{requires input.shape[1] (3) * 3 to be equal to output.shape[1] (6)}}
%0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<4x6xf32>
return
}
// -----
// Test reference variable support for some ops (no errors expected)
// CHECK-LABEL: @testMaximumWithRef
func @testMaximumWithRef(%arg0: tensor<!tf.f32ref>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: tf.Maximum
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<!tf.f32ref>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: @testAddV2WithRef
func @testAddV2WithRef(%arg0: tensor<!tf.int16ref>, %arg1: tensor<i16>) -> tensor<i16> {
// CHECK: tf.AddV2
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<!tf.int16ref>, tensor<i16>) -> tensor<i16>
return %0 : tensor<i16>
}
// CHECK-LABEL: @testRealDivWithRef
func @testRealDivWithRef(%arg0: tensor<f64>, %arg1: tensor<!tf.f64ref>) -> tensor<f64> {
// CHECK: tf.RealDiv
%0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<f64>, tensor<!tf.f64ref>) -> tensor<f64>
return %0 : tensor<f64>
}
// CHECK-LABEL: @testDivNoNanWithRef
func @testDivNoNanWithRef(%arg0: tensor<f32>, %arg1: tensor<!tf.f32ref>) -> tensor<f32> {
// CHECK: tf.DivNoNan
%0 = "tf.DivNoNan"(%arg0, %arg1) : (tensor<f32>, tensor<!tf.f32ref>) -> tensor<f32>
return %0 : tensor<f32>
}

View File

@ -433,7 +433,7 @@ func @nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> {
%1:3 = tf_executor.NextIteration.Source : tensor<*xf32>
tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32>
tf_executor.fetch %1#0 : tensor<*xf32>
}
return %0 : tensor<*xf32>
@ -445,7 +445,7 @@ func @nextiteration_with_attributes(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*
%1:3 = tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
tf_executor.NextIteration.Sink[%1#1] %1#0 : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32> {attr3 = 32 : i64, tf_executor.attr_fetch = "some_value"}
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}} : tensor<*xf32> {attr4 = 42 : i64, tf_executor.attr_push = "other_value"}
tf_executor.fetch %1#0 : tensor<*xf32>
}
return %0 : tensor<*xf32>
@ -457,9 +457,9 @@ func @nextiteration_control(%arg0: tensor<*xf32>, %arg1: tensor<i1>) -> tensor<*
%1:3 = tf_executor.Switch %arg0, %arg1 : tensor<*xf32>
%2:2 = tf_executor.Enter %arg0, %1#2, %1#2 frame "some/frame" : tensor<*xf32>
%3:3 = tf_executor.NextIteration.Source : tensor<*xf32>
tf_executor.NextIteration.Sink [%3#1] %3#0, %1#2 : tensor<*xf32>
tf_executor.NextIteration.Sink[%3#1] %3#0, %1#2 : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Source : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Sink [%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32>
// CHECK: tf_executor.NextIteration.Sink[%{{.*}}] %{{.*}}, %{{.*}} : tensor<*xf32>
tf_executor.fetch %3#0 : tensor<*xf32>
}
return %0 : tensor<*xf32>

View File

@ -0,0 +1,64 @@
// RUN: tf-opt -tf-tpu-resource-read-for-write %s | FileCheck %s --dump-input=always
// CHECK-LABEL: func @write_only_resource
// CHECK-SAME: ([[ARG0:%.*]]: tensor<i32>, [[ARG1:%.*]]: tensor<f32>, [[ARG2:%.*]]: tensor<*x!tf.resource<tensor<i32>>>)
func @write_only_resource(%arg0: tensor<i32>, %arg1: tensor<f32>, %arg2: tensor<*x!tf.resource<tensor<i32>>>) {
// CHECK-NEXT: [[READ:%.*]] = "tf.ReadVariableOp"([[ARG2]])
// CHECK-NEXT: [[CLUSTER:%.*]]:2 = "tf_device.cluster_func"([[ARG0]], [[ARG1]], [[READ]])
// CHECK-SAME: _tpu_replicate = "write"
%0:2 = "tf_device.cluster_func"(%arg0, %arg1) {_tpu_replicate = "write", func = @write_func} : (tensor<i32>, tensor<f32>) -> (tensor<f32>, tensor<i32>)
// CHECK-NEXT: "tf.AssignVariableOp"([[ARG2]], [[CLUSTER]]#1)
"tf.AssignVariableOp"(%arg2, %0#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
// CHECK-NEXT: return
return
}
// CHECK-LABEL: func @write_func
// CHECK-SAME: ({{%.*}}: tensor<i32>, {{%.*}}: tensor<f32>, {{%.*}}: tensor<i32>) -> (tensor<f32>, tensor<i32>)
func @write_func(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i32>) {
return %arg1, %arg0 : tensor<f32>, tensor<i32>
}
// CHECK-LABEL: func @read_write_resource
func @read_write_resource(%arg0: tensor<i32>, %arg1: tensor<f32>, %arg2: tensor<*x!tf.resource<tensor<i32>>>) {
// CHECK-COUNT-1: tf.ReadVariableOp
%0 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
%1:2 = "tf_device.cluster_func"(%arg0, %arg1, %0) {_tpu_replicate = "read_write", func = @read_write_func} : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<f32>, tensor<i32>)
"tf.AssignVariableOp"(%arg2, %1#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
// CHECK-LABEL: func @read_write_func
// CHECK-SAME: ({{%.*}}: tensor<i32>, {{%.*}}: tensor<f32>) -> (tensor<f32>, tensor<i32>)
func @read_write_func(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i32>) {
return %arg1, %arg0 : tensor<f32>, tensor<i32>
}
// CHECK-LABEL: func @multiple_write_resource
func @multiple_write_resource(%arg0: tensor<i32>, %arg1: tensor<*x!tf.resource<tensor<i32>>>) {
// CHECK-NOT: tf.ReadVariableOp
%0:2 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_write", func = @multiple_write_func} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg1, %0#0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %0#1) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
// CHECK-LABEL: func @multiple_write_func
// CHECK-SAME: ({{%.*}}: tensor<i32>) -> (tensor<i32>, tensor<i32>)
func @multiple_write_func(%arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
return %arg0, %arg0 : tensor<i32>, tensor<i32>
}
// CHECK-LABEL: func @multiple_result_user
func @multiple_result_user(%arg0: tensor<i32>, %arg1: tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32> {
// CHECK-NOT: tf.ReadVariableOp
%0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "multiple_uses", func = @multiple_result_user_func} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%arg1, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return %0 : tensor<i32>
}
// CHECK-LABEL: func @multiple_result_user_func
// CHECK-SAME: ({{%.*}}: tensor<i32>) -> tensor<i32>
func @multiple_result_user_func(%arg0: tensor<i32>) -> tensor<i32> {
return %arg0 : tensor<i32>
}

View File

@ -173,7 +173,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @tail_single_outside_compiled_op() {
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: "tf.NoOp"
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
@ -190,7 +190,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
"tf_device.cluster"() ( {
%a = "tf.A"() : () -> tensor<i32>
"tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
"tf.C"() : () -> ()
"tf.NoOp"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
@ -200,7 +200,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @tail_single_outside_compiled_op_user() -> tensor<i32> {
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: "tf.NoOp"
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
@ -217,7 +217,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%cluster = "tf_device.cluster"() ( {
%a = "tf.A"() : () -> tensor<i32>
%b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
"tf.C"() : () -> ()
"tf.NoOp"() : () -> ()
tf_device.return %b : tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
// CHECK: return %[[LAUNCH_OUT]]
@ -262,7 +262,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%b = "tf.B"() : () -> tensor<i32>
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
// CHECK-NEXT: %[[E_OUT:.*]] = "tf.E"
// CHECK-NEXT: %[[E_OUT:.*]] = "tf.Const"
// CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
@ -279,7 +279,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%cluster:5 = "tf_device.cluster"() ( {
%c = "tf.C"() : () -> tensor<i32>
%d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%e = "tf.E"() : () -> tensor<i32>
%e = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %a, %b, %c, %d, %e : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
// CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1
@ -320,14 +320,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor<i32>) {
// CHECK-NOT: "tf_device.launch"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.Identity"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: "tf.Identity"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%a = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
%a = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
"tf.C"(%b) : (tensor<i32>) -> ()
%c = "tf.Identity"(%b) : (tensor<i32>) -> tensor<i32>
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
@ -379,7 +379,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]])
// CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
// CHECK-NEXT: "tf.IdentityN"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
@ -399,11 +399,72 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%b = "tf.B"() : () -> tensor<i32>
%c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
%e = "tf.E"(%c, %a) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%e:2 = "tf.IdentityN"(%c, %a) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
tf_device.return
}
return
}
// CHECK-LABEL: func @side_effect_middle
func @side_effect_middle() {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @side_effect_head_no_operand
func @side_effect_head_no_operand() {
// CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Const"
// CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]])
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
%c = "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> tensor<i32>
"tf.D"(%c) : (tensor<i32>) -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @side_effect_tail_no_operand
func @side_effect_tail_no_operand() {
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: "tf.Const"
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK: "tf_device.launch"()
// CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( {
%a = "tf.A"() : () -> tensor<i32>
"tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
"tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
%cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
}

View File

@ -110,6 +110,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
pm.addPass(TF::CreateResourceDeviceInferencePass());
pm.addPass(TFDevice::CreateClusterOutliningPass());
pm.addPass(CreateTPUDynamicPaddingMapperPass());
pm.addPass(CreateTPUResourceReadForWritePass());
pm.addPass(CreateTPUShardingIdentificationPass());
pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
pm.addPass(CreateTPURewritePass());

View File

@ -629,8 +629,7 @@ class Lower_UnaryOpsComposition
LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op,
PatternRewriter &rewriter) const override {
Value result = op.x();
for (StringRef op_name :
op.op_names().getAsRange<StringAttr, StringRef>()) {
for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
std::string full_name = "tf." + op_name.str();
// All ops in the sequences have the same result type as the original
// result type.

View File

@ -287,6 +287,10 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass();
// `tf_device.launch_func` `padding_map` attribute to its encapsulated function.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources
// the cluster only writes to.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass();
// Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime
// ops.
std::unique_ptr<OperationPass<ModuleOp>> CreateTPURewritePass();

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
@ -34,6 +35,7 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@ -118,7 +120,10 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
// computation or other ops that can be extracted, and have no operands from
// other ops in the TPU computation that cannot be extracted.
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
const TF::SideEffectAnalysis& side_effect_analysis,
tf_device::ClusterOp cluster) {
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
cluster.getParentOfType<FuncOp>());
Region* cluster_region = &cluster.body();
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
@ -127,6 +132,15 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
// An outside compiled op can be extracted if its operands are not from
// other ops in the cluster that cannot be extracted.
// If this op is side effecting, check whether its closest preceding
// side-effecting op can also be head extracted. Side effects impose an
// ordering, so if that predecessor cannot be extracted, this op cannot be
// head extracted either.
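// (Sketch: in the side_effect_head_no_operand test in this change, "tf.C"
// can only be head extracted because "tf.B", the side-effecting op before
// it, is head extracted as well.)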
auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
if (!predecessors.empty() &&
!head_outside_compiled_ops.contains(predecessors.back()))
continue;
auto walk_result = cluster_op.walk([&](Operation* op) {
for (Value operand : op->getOperands()) {
Operation* operand_op = GetOpOfValue(operand);
@ -168,11 +182,11 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
// Extracts and moves outside compiled ops that have no dependencies in the
// cluster to before the cluster.
mlir::LogicalResult LiftHeadOutsideCompiledOps(
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
tf_device::ClusterOp cluster, std::string* host_device,
bool* cluster_updated) {
OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
std::string* host_device, bool* cluster_updated) {
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
FindOutsideCompiledOpsAtHead(cluster);
FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
if (head_outside_compiled_ops.empty()) return success();
if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
host_device)))
@ -191,9 +205,12 @@ mlir::LogicalResult LiftHeadOutsideCompiledOps(
// TPU computation or other ops that can be extracted, and have no results used
// by other ops in the TPU computation that cannot be extracted.
void FindOutsideCompiledOpsAtTailAndClusterResults(
const TF::SideEffectAnalysis& side_effect_analysis,
tf_device::ClusterOp cluster,
llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
llvm::SmallVectorImpl<Value>* cluster_results) {
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
cluster.getParentOfType<FuncOp>());
Region* cluster_region = &cluster.body();
llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
Operation* terminator = cluster.GetBody().getTerminator();
@ -205,6 +222,15 @@ void FindOutsideCompiledOpsAtTailAndClusterResults(
for (Operation& cluster_op : cluster_ops) {
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
// If this op is side effecting, check whether its closest following
// side-effecting op can also be tail extracted. Side effects impose an
// ordering, so if that successor cannot be extracted, this op cannot be
// tail extracted either.
auto successors = analysis.DirectControlSuccessors(
&cluster_op, [&terminator](Operation* op) { return op != terminator; });
if (!successors.empty() &&
!tail_outside_compiled_ops_set.contains(successors.front()))
continue;
llvm::SmallVector<int, 4> results_to_forward;
bool can_be_extracted =
llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
@ -293,13 +319,14 @@ tf_device::ClusterOp UpdateClusterResults(
// Extracts and moves outside compiled ops that do not create dependencies in the
// cluster to after the cluster.
mlir::LogicalResult LiftTailOutsideCompiledOps(
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
std::string host_device, tf_device::ClusterOp* cluster,
bool* cluster_updated) {
OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
const mlir::TF::RuntimeDevices& devices, std::string host_device,
tf_device::ClusterOp* cluster, bool* cluster_updated) {
llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
llvm::SmallVector<Value, 4> cluster_results;
FindOutsideCompiledOpsAtTailAndClusterResults(
*cluster, &tail_outside_compiled_ops, &cluster_results);
FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
&tail_outside_compiled_ops,
&cluster_results);
if (tail_outside_compiled_ops.empty()) return success();
if (host_device.empty())
@ -365,6 +392,7 @@ struct TPUExtractHeadTailOutsideCompilation
};
void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
// Get runtime devices information from the closest parent module.
auto module = getOperation();
mlir::TF::RuntimeDevices devices;
@ -379,10 +407,12 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
for (tf_device::ClusterOp cluster : clusters) {
std::string host_device;
bool cluster_updated = false;
if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster,
&host_device, &cluster_updated)) ||
failed(LiftTailOutsideCompiledOps(&builder, devices, host_device,
&cluster, &cluster_updated)))
if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
devices, cluster, &host_device,
&cluster_updated)) ||
failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
devices, host_device, &cluster,
&cluster_updated)))
return signalPassFailure();
if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
}

View File

@ -0,0 +1,140 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
namespace TFTPU {
// A pass that finds TPU clusters with write-only resource access and adds an
// associated resource read, so the resource can later be fused into TPUExecute.
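// (Sketch, mirroring the write_only_resource test in this change: a cluster
// that only assigns a variable gains a new "tf.ReadVariableOp" whose result
// is appended to the cluster operands and to the callee's arguments.)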
namespace {
struct TPUResourceReadForWrite
: public PassWrapper<TPUResourceReadForWrite, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
// Helper struct holding a resource value and its associated type.
struct ResourceValueAndSubtype {
Value resource;
Type subtype;
};
// Finds resource handle and type for result if result writes to a resource.
ResourceValueAndSubtype GetResourceWriteResult(
tf_device::ClusterFuncOp cluster_func, Value result) {
ResourceValueAndSubtype resource;
if (!result.hasOneUse()) return resource;
Operation* result_user = *result.getUsers().begin();
auto assign_var = dyn_cast<TF::AssignVariableOp>(result_user);
if (!assign_var) return resource;
auto handle = assign_var.resource();
// Skip result if cluster writes to the same variable via multiple results.
for (Operation* handle_user : handle.getUsers()) {
if (handle_user == assign_var) continue;
auto assign_var_user = dyn_cast<TF::AssignVariableOp>(handle_user);
if (!assign_var_user) continue;
if (assign_var_user.value().getDefiningOp() == cluster_func)
return resource;
}
resource.resource = assign_var.resource();
resource.subtype = assign_var.value().getType();
return resource;
}
// Checks if resource is read by TPU cluster.
bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
Value resource) {
for (Operation* resource_user : resource.getUsers())
if (auto read = dyn_cast<TF::ReadVariableOp>(resource_user))
for (Operation* read_user : read.value().getUsers())
if (read_user == cluster_func) return true;
return false;
}
void TPUResourceReadForWrite::runOnOperation() {
SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
cluster_funcs.push_back(cluster_func);
});
OpBuilder builder(&getContext());
// Add resource reads for resources the TPU cluster writes to but does not
// already read from.
for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
builder.setInsertionPoint(cluster_func);
SmallVector<Value, 4> read_operands;
for (Value result : cluster_func.getResults()) {
// TODO(lyandy): Update pass to use resource alias analysis.
auto resource_and_type = GetResourceWriteResult(cluster_func, result);
if (!resource_and_type.resource) continue;
if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
continue;
auto new_read = builder.create<TF::ReadVariableOp>(
resource_and_type.resource.getLoc(), resource_and_type.subtype,
resource_and_type.resource);
read_operands.push_back(new_read.value());
}
if (read_operands.empty()) continue;
// Update caller and function types with new read operands.
auto operands = llvm::to_vector<4>(cluster_func.getOperands());
operands.append(read_operands.begin(), read_operands.end());
auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
cluster_func.getLoc(), cluster_func.getResultTypes(), operands,
cluster_func.getAttrs());
cluster_func.replaceAllUsesWith(new_cluster_func);
FuncOp func = cluster_func.getFunc();
Block& block = func.front();
for (Value read_operand : read_operands)
block.addArgument(read_operand.getType());
func.setType(FunctionType::get(block.getArgumentTypes(),
func.getCallableResults(), &getContext()));
cluster_func.erase();
}
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
return std::make_unique<TPUResourceReadForWrite>();
}
static PassRegistration<TPUResourceReadForWrite> pass(
"tf-tpu-resource-read-for-write",
"Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes "
"with no reads");
} // namespace TFTPU
} // namespace mlir

View File

@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
@ -80,46 +81,14 @@ constexpr char kInvalidExecutorGraphMsg[] =
constexpr char kDeviceAttr[] = "tf.device";
constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
bool IsLegalChar(char c, bool first_char) {
if (isalpha(c)) return true;
if (isdigit(c)) return true;
if (c == '.') return true;
if (c == '_') return true;
// First character of a node name can only be a letter, digit, dot or
// underscore.
if (first_char) return false;
if (c == '/') return true;
if (c == '-') return true;
return false;
}
// Convert characters in name that are considered illegal in TensorFlow Node
// name to '.'.
std::string LegalizeNodeName(llvm::StringRef name) {
assert(!name.empty() && "expected non-empty name");
std::string legalized_name;
bool first = true;
for (auto c : name) {
if (IsLegalChar(c, first)) {
legalized_name += c;
} else {
legalized_name += '.';
}
first = false;
}
return legalized_name;
}
// OpOrArgLocNameMapper that legalizes the returned name.
class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
private:
std::string GetName(OpOrVal op_or_val) override {
return LegalizeNodeName(OpOrArgLocNameMapper::GetName(op_or_val));
std::string name = OpOrArgLocNameMapper::GetName(op_or_val);
assert(!name.empty() && "expected non-empty name");
mlir::LegalizeNodeName(name);
return name;
}
};
@ -523,13 +492,14 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
if (index >= num_data_results) break;
// TODO(jpienaar): If there is a result index specified, ensure only one
// and that it matches the result index of the op.
std::string orig_name(output_names[index]);
auto tensor_id = ParseTensorName(orig_name);
auto name = LegalizeNodeName(
llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));
std::string name(output_names[index]);
auto tensor_id = ParseTensorName(name);
std::string tensor_id_node(tensor_id.node());
assert(!tensor_id_node.empty() && "expected non-empty name");
mlir::LegalizeNodeName(tensor_id_node);
// Ensure name does not get reused.
(void)exporter.op_to_name_.GetUniqueName(name);
(void)exporter.op_to_name_.GetUniqueName(tensor_id_node);
}
}
@ -537,8 +507,9 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
TF_RET_CHECK(input_names.size() == block.getNumArguments());
for (const auto& it : llvm::enumerate(function.getArguments())) {
// TODO(lyandy): Update when changing feed/fetch import.
std::string orig_name(input_names[it.index()]);
std::string name = LegalizeNodeName(orig_name);
std::string name(input_names[it.index()]);
assert(!name.empty() && "expected non-empty name");
mlir::LegalizeNodeName(name);
auto tensor_id = ParseTensorName(name);
TF_RET_CHECK(tensor_id.index() == 0)
<< "input port designation not supported";

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_
#include "absl/types/span.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
template <typename T>
inline llvm::ArrayRef<T> SpanToArrayRef(absl::Span<const T> span) {
return llvm::ArrayRef<T>(span.data(), span.size());
}
template <typename T>
inline llvm::ArrayRef<T> SpanToArrayRef(absl::Span<T> span) {
return llvm::ArrayRef<T>(span.data(), span.size());
}
template <typename T>
inline llvm::MutableArrayRef<T> SpanToMutableArrayRef(absl::Span<T> span) {
return llvm::MutableArrayRef<T>(span.data(), span.size());
}
template <typename T>
inline absl::Span<const T> ArrayRefToSpan(llvm::ArrayRef<T> ref) {
return absl::Span<const T>(ref.data(), ref.size());
}
template <typename T>
inline absl::Span<T> MutableArrayRefToSpan(llvm::MutableArrayRef<T> ref) {
return absl::Span<T>(ref.data(), ref.size());
}
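// Usage sketch (assumed caller code, not part of this header): the helpers
// only rebind the underlying (data, size) pair; no copy is made.
//   absl::Span<const int> span = ...;
//   llvm::ArrayRef<int> ref = mlir::SpanToArrayRef(span);
//   absl::Span<const int> round_trip = mlir::ArrayRefToSpan(ref);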
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_

View File

@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include <cctype>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/IR/Identifier.h" // from @llvm-project
namespace mlir {
namespace {
// Checks if a character is legal for a TensorFlow node name, with special
// handling if a character is at the beginning.
bool IsLegalChar(char c, bool first_char) {
if (isalpha(c)) return true;
if (isdigit(c)) return true;
if (c == '.') return true;
if (c == '_') return true;
// First character of a node name can only be a letter, digit, dot or
// underscore.
if (first_char) return false;
if (c == '/') return true;
if (c == '-') return true;
return false;
}
} // anonymous namespace
void LegalizeNodeName(std::string& name) {
if (name.empty()) return;
if (!IsLegalChar(name[0], /*first_char=*/true)) name[0] = '.';
for (char& c : llvm::drop_begin(name, 1))
if (!IsLegalChar(c, /*first_char=*/false)) c = '.';
}
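// Example (hypothetical names, not from a test): "conv1:0" becomes "conv1.0",
// and "/conv1" becomes ".conv1" since '/' is not allowed as a first character.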
std::string GetNameFromLoc(Location loc) {
llvm::SmallVector<llvm::StringRef, 8> loc_names;
llvm::SmallVector<Location, 8> locs;
locs.push_back(loc);
bool names_is_nonempty = false;
while (!locs.empty()) {
Location curr_loc = locs.pop_back_val();
if (auto name_loc = curr_loc.dyn_cast<NameLoc>()) {
// Add the name from a NameLoc. For ops in functions the stored name has the
// form "op@func", so only the op portion (before '@') is kept.
auto name = name_loc.getName().strref().split('@').first;
loc_names.push_back(name);
if (!name.empty()) names_is_nonempty = true;
continue;
} else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>()) {
// Add name if CallSiteLoc's callee has a NameLoc (as should be the
// case if imported with DebugInfo).
if (auto name_loc = call_loc.getCallee().dyn_cast<NameLoc>()) {
auto name = name_loc.getName().strref().split('@').first;
loc_names.push_back(name);
if (!name.empty()) names_is_nonempty = true;
continue;
}
} else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>()) {
// Push all locations in FusedLoc in reverse order, so locations are
// visited based on order in FusedLoc.
auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations());
locs.append(reversed_fused_locs.begin(), reversed_fused_locs.end());
continue;
}
// Location is not a supported type, so an empty StringRef is added.
loc_names.push_back(llvm::StringRef());
}
if (names_is_nonempty)
return llvm::join(loc_names.begin(), loc_names.end(), ";");
return "";
}
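// Example (hypothetical locations): a FusedLoc of NameLocs "x" and "y"
// yields "x;y"; an unsupported location contributes an empty segment.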
} // namespace mlir

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_
#include <string>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Location.h" // from @llvm-project
namespace mlir {
// Converts characters in name that are considered illegal in TensorFlow Node
// name to '.'.
void LegalizeNodeName(std::string& name);
// Creates a TensorFlow node name from a location.
std::string GetNameFromLoc(Location loc);
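// Names recovered from nested or fused locations are joined with ';'; an empty
// string is returned if no name can be recovered.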
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
inline absl::string_view StringRefToView(llvm::StringRef ref) {
return absl::string_view(ref.data(), ref.size());
}
inline llvm::StringRef StringViewToRef(absl::string_view view) {
return llvm::StringRef(view.data(), view.size());
}
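// Usage sketch (assumed caller code): mlir::StringRefToView(op_name) and
// mlir::StringViewToRef(view) are zero-copy; both simply wrap the same
// character range with the other library's string view type.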
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_

View File

@ -238,7 +238,6 @@ cc_library(
deps = [
":type_to_shape",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
"//tensorflow/compiler/mlir/tensorflow:convert_type",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/tf2xla:common",
@ -389,7 +388,6 @@ cc_library(
":xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/hlo:legalize_control_flow",
"//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation",

Some files were not shown because too many files have changed in this diff.