Merge remote-tracking branch 'upstream/master' into arc_mli_evaltensor_porting_conv
This commit is contained in:
commit
ae7944f9c0
@ -45,6 +45,11 @@
|
||||
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
|
||||
* Use `NnApiDelegate()` and related delegate configuration methods
|
||||
directly.
|
||||
* 16-bit quantization
|
||||
* Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
|
||||
* Added support for saved model's session initializer through
|
||||
`TFLiteConverter.from_saved_model`.
|
||||
|
||||
* TF Core:
|
||||
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
|
||||
`tf.while_loop`, and compositions like `tf.foldl`) computed with
|
||||
|
@ -199,6 +199,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":logging",
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
@ -138,8 +139,8 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are legacy features in TF Eager Runtime.
|
||||
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||
// Following are features in current TF Eager Runtime.
|
||||
// TODO(tfrt-devs): Figure out a way to deprecate following features after
|
||||
// migrated to TFRT.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Clear pending nodes in thread executors and kernel caches.
|
||||
@ -157,6 +158,22 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
// Update the Eager Executor for current thread.
|
||||
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Following are helper functions to assist integrating TFRT with current
|
||||
// TF eager runtime.
|
||||
// TODO(b/172877902): These helper functions are currently used to support
|
||||
// PyFuncOp on TFRT, and might be useful for ops that directly use low
|
||||
// level TF APIs. Remove/replace the following functions when TFRT native
|
||||
// ops are implemented.
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Create an abstract tensor handle from tensorflow::Tensor.
|
||||
virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
|
||||
tensorflow::Tensor& t, const char* d_name) = 0;
|
||||
|
||||
// Convert a TFRT TensorHandle to tensorflow::TensorHandle.
|
||||
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
|
||||
ImmediateExecutionTensorHandle* handle) = 0;
|
||||
|
||||
protected:
|
||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||
: AbstractContext(kind) {}
|
||||
|
@ -76,6 +76,7 @@ cc_library(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -328,6 +328,17 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int expected_max_outputs,
|
||||
TF_Status* status) const {
|
||||
std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
|
||||
return Execute(context, inputs, operation_name, attributes,
|
||||
expected_output_shapes, status);
|
||||
}
|
||||
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
|
||||
ParallelDevice::Execute(
|
||||
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
const std::vector<PartialTensorShape>& expected_output_shapes,
|
||||
TF_Status* status) const {
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
|
||||
// Compute per-device per-output tensors
|
||||
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
|
||||
@ -344,7 +355,7 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
}
|
||||
device_thread->StartExecute(context, operation_name,
|
||||
std::move(device_inputs), attributes,
|
||||
expected_max_outputs);
|
||||
expected_output_shapes.size());
|
||||
}
|
||||
StatusPtr first_bad_status(nullptr);
|
||||
for (int device_index = 0; device_index < underlying_devices_.size();
|
||||
@ -386,8 +397,15 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
for (int j = 0; j < underlying_devices_.size(); ++j) {
|
||||
components.push_back(std::move(per_device_output_tensors[j][i]));
|
||||
}
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
if (expected_output_shapes[i].IsFullyDefined()) {
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components),
|
||||
absl::Span<const int64>(expected_output_shapes[i].dim_sizes()),
|
||||
status));
|
||||
} else {
|
||||
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
|
||||
*this, std::move(components), status));
|
||||
}
|
||||
if (TF_GetCode(status) != TF_OK) return result;
|
||||
}
|
||||
result.emplace(std::move(per_device_outputs));
|
||||
@ -396,9 +414,27 @@ ParallelDevice::Execute(TFE_Context* context,
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
|
||||
TF_Status* status) {
|
||||
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
|
||||
std::vector<int64_t> shape(
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return std::unique_ptr<ParallelTensor>(
|
||||
new ParallelTensor(parallel_device, std::move(components), shape, dtype));
|
||||
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
std::vector<int64> shape(
|
||||
TFE_TensorHandleNumDims(components[0].get(), status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
@ -406,11 +442,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
// Verify that the TensorHandle's shape matches all of the component shapes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
int64 tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
if (tensor_dim != shape[i]) {
|
||||
// TODO(allenl): Allow shapes to differ.
|
||||
@ -419,17 +454,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
"the same shape");
|
||||
return nullptr;
|
||||
}
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
|
||||
parallel_device, std::move(components), std::move(shape), dtype));
|
||||
return FromTensorHandles(parallel_device, std::move(components),
|
||||
absl::Span<const int64>(shape), status);
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace parallel_device {
|
||||
@ -93,6 +94,15 @@ class ParallelDevice {
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
int expected_max_outputs, TF_Status* status) const;
|
||||
|
||||
// Accepts inferred shapes for outputs, which if fully defined will avoid
|
||||
// querying the shapes of the underlying TensorHandles. This allows async
|
||||
// computation to continue without blocking.
|
||||
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
|
||||
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
|
||||
const char* operation_name, const TFE_OpAttrs* attributes,
|
||||
const std::vector<PartialTensorShape>& expected_output_shapes,
|
||||
TF_Status* status) const;
|
||||
|
||||
private:
|
||||
// A sequence of device names, indicating which devices replicated operations
|
||||
// are forwarded to.
|
||||
@ -117,10 +127,15 @@ class ParallelDevice {
|
||||
class ParallelTensor {
|
||||
public:
|
||||
// Construct a ParallelTensor from TensorHandles placed on the component
|
||||
// devices of a ParallelDevice.
|
||||
// devices of a ParallelDevice. Inspects `components` to determine a shape.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status);
|
||||
// Uses the provided shape without additional checks, which avoids blocking.
|
||||
static std::unique_ptr<ParallelTensor> FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
|
||||
TF_Status* status);
|
||||
|
||||
size_t num_tensors() const { return tensors_.size(); }
|
||||
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
|
||||
@ -132,10 +147,10 @@ class ParallelTensor {
|
||||
private:
|
||||
ParallelTensor(const ParallelDevice& device,
|
||||
std::vector<TensorHandlePtr> tensors,
|
||||
std::vector<int64_t> shape, const TF_DataType dtype)
|
||||
absl::Span<const int64> shape, const TF_DataType dtype)
|
||||
: device_(device),
|
||||
tensors_(std::move(tensors)),
|
||||
shape_(std::move(shape)),
|
||||
shape_(shape.begin(), shape.end()),
|
||||
dtype_(dtype) {}
|
||||
|
||||
const ParallelDevice& device_;
|
||||
|
@ -80,5 +80,41 @@ TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
// Checks that the ParallelDevice::Execute overload taking explicit expected
// output shapes produces a ParallelTensor whose shape has zero dimensions when
// a scalar (rank-0) PartialTensorShape is supplied for the single output.
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
|
||||
// Status object shared by every C API call below; released by TF_DeleteStatus.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
// Context configuration requesting two CPU devices, matching the two device
// names handed to the ParallelDevice below.
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
// Covers both the config-setting call and the context creation above.
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Build a parallel device over the two CPU devices.
std::vector<std::string> devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
ParallelDevice parallel_device(std::move(devices));
|
||||
// The op is created only to carry attributes; its attrs are passed to
// Execute via TFE_OpGetAttrs below.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
|
||||
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
|
||||
// "shape" attribute with num_dims == 0 and no dims array.
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
// Execute with no inputs and one expected output shape, PartialTensorShape({}).
auto outputs = parallel_device.Execute(
|
||||
context.get(), std::vector<ParallelTensor*>(), "VarHandleOp",
|
||||
TFE_OpGetAttrs(handle_op.get()), {PartialTensorShape({})}, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
|
||||
// The resulting ParallelTensor is expected to report a shape with no dims.
EXPECT_EQ(0, handles[0]->shape().size());
|
||||
}
|
||||
|
||||
} // namespace parallel_device
|
||||
} // namespace tensorflow
|
||||
|
@ -508,7 +508,7 @@ cc_library(
|
||||
":flags",
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
|
||||
"//tensorflow/compiler/tf2xla:mlir_bridge_pass",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
|
@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_arg_indices, inputs, variable_infos);
|
||||
constant_arg_indices, inputs, variable_infos, dev);
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
|
||||
switch (stage) {
|
||||
|
@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
|
||||
may_alias_resource_update;
|
||||
|
||||
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
|
||||
variable_infos);
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constants, inputs, variable_infos,
|
||||
static_cast<Device*>(ctx->device()));
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
return cache->Compile(options, function, *args, compile_options,
|
||||
lazy ? XlaCompilationCache::CompileMode::kLazy
|
||||
@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator = GetAllocator(
|
||||
&tf_allocator_adapter, ctx->device(),
|
||||
|
@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
|
||||
for (const XlaCompiler::Argument& arg : args) {
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
case XlaCompiler::Argument::kConstantResource:
|
||||
signature.arg_values.push_back(arg.constant_value);
|
||||
break;
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
@ -288,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
|
||||
bool use_mlir = config &&
|
||||
GetMlirBridgeRolloutPolicy(*config) ==
|
||||
GetMlirBridgeRolloutPolicy(*graph, *config) ==
|
||||
MlirBridgeRolloutPolicy::kEnabledByUser &&
|
||||
node_def.op() != "VarIsInitializedOp";
|
||||
#ifdef LIBTPU_ON_GCE
|
||||
|
@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
ctx, variables_indices, variable_infos, variable_args));
|
||||
|
||||
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
constant_input_indices, inputs, variable_infos);
|
||||
constant_input_indices, inputs, variable_infos,
|
||||
static_cast<Device*>(ctx->device()));
|
||||
TF_RETURN_IF_ERROR(args.status());
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
@ -23,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
@ -89,10 +89,21 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
// Make sure that kernels have been registered on the JIT device.
|
||||
XlaOpRegistry::RegisterCompilationKernels();
|
||||
|
||||
// Get function body, constant args, and resource args.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
// Only check for compilability if the MLIR bridge is not enabled.
|
||||
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(absl::nullopt);
|
||||
if (policy == MlirBridgeRolloutPolicy::kDisabledByUser ||
|
||||
policy == MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) {
|
||||
absl::optional<ConfigProto> config_proto;
|
||||
if (flr->config_proto()) {
|
||||
config_proto = *flr->config_proto();
|
||||
}
|
||||
if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
||||
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
@ -121,15 +132,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
}
|
||||
}
|
||||
|
||||
// Get function body, constant args, and resource args.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
|
||||
|
||||
MemoryTypeVector input_memory_types =
|
||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||
|
@ -564,11 +564,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
||||
XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
absl::Span<int const> must_be_constant_idxs,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_args) {
|
||||
absl::Span<VariableInfo const> variable_args, Device* device) {
|
||||
CHECK(absl::c_is_sorted(must_be_constant_idxs));
|
||||
std::vector<XlaCompiler::Argument> out;
|
||||
out.resize(inputs.size());
|
||||
|
||||
// TODO(cheshire): Avoid duplication with framework/op_kernel.h
|
||||
DeviceContext* device_context = nullptr;
|
||||
TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
|
||||
bool using_default_context = false;
|
||||
auto cleanup = xla::MakeCleanup([&] {
|
||||
if (device_context != nullptr && !using_default_context) {
|
||||
device_context->Unref();
|
||||
}
|
||||
});
|
||||
if (device_context == nullptr) {
|
||||
using_default_context = true;
|
||||
auto* dev_info = device->tensorflow_gpu_device_info();
|
||||
if (dev_info) device_context = dev_info->default_context;
|
||||
}
|
||||
|
||||
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
|
||||
for (const VariableInfo& info : variable_args) {
|
||||
CHECK(!info.var() || info.lock_held())
|
||||
@ -581,18 +596,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
const Tensor* input = inputs[input_num];
|
||||
|
||||
XlaCompiler::Argument& arg = out[input_num];
|
||||
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||
// Handles compile-time constants.
|
||||
|
||||
// TODO(b/157241314): Support constants located in resource variables.
|
||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE)
|
||||
<< "tf2xla bridge does not support must-be-constants located in "
|
||||
"resource variables; try moving them to a tensor";
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.type = input->dtype();
|
||||
arg.shape = input->shape();
|
||||
arg.constant_value = *input;
|
||||
} else if (variable_info_lookup.count(input_num)) {
|
||||
if (variable_info_lookup.count(input_num)) {
|
||||
// Handles resource variables.
|
||||
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
|
||||
const VariableInfo& variable = *variable_info_lookup[input_num];
|
||||
@ -613,6 +617,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
|
||||
arg.type = DT_INVALID;
|
||||
arg.shape = TensorShape();
|
||||
}
|
||||
|
||||
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||
TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
|
||||
const Tensor* value = variable.var()->tensor();
|
||||
Tensor value_on_host(value->dtype(), value->shape());
|
||||
if (!device_context) {
|
||||
value_on_host = *value;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
|
||||
value, "", device, &value_on_host));
|
||||
}
|
||||
arg.kind = XlaCompiler::Argument::kConstantResource;
|
||||
arg.constant_value = value_on_host;
|
||||
}
|
||||
} else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.type = input->dtype();
|
||||
arg.shape = input->shape();
|
||||
arg.constant_value = *input;
|
||||
} else {
|
||||
// Normal inputs.
|
||||
TF_RET_CHECK(input->dtype() != DT_RESOURCE);
|
||||
|
@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
|
||||
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
|
||||
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
|
||||
absl::Span<const Tensor* const> inputs,
|
||||
absl::Span<VariableInfo const> variable_args);
|
||||
absl::Span<VariableInfo const> variable_args,
|
||||
Device* device);
|
||||
|
||||
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
|
||||
// `variables` is a map from TensorFlow argument number to resource variable.
|
||||
|
@ -3,7 +3,11 @@
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_binary",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -126,12 +130,14 @@ cc_library(
|
||||
srcs = ["mlir_graph_optimization_pass.cc"],
|
||||
hdrs = ["mlir_graph_optimization_pass.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
@ -198,11 +204,22 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "mlir_graph_optimization_pass_test",
|
||||
srcs = ["mlir_graph_optimization_pass_test.cc"],
|
||||
deps = [
|
||||
":mlir_graph_optimization_pass",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "litfiles",
|
||||
srcs = glob(["runlit*py"]),
|
||||
|
@ -87,6 +87,32 @@ Value InsertAlloc(Location loc, OpResult result,
|
||||
return alloc;
|
||||
}
|
||||
|
||||
/// Converts the results of the operation `op` to memref types and appends them
|
||||
/// to the `results` vector.
///
/// Returns failure() if any result is not a RankedTensorType, if a result has
/// a dynamic shape and `op` does not implement InferShapedTypeOpInterface, or
/// if reifyReturnTypeShapes fails; returns success() otherwise. Values pushed
/// before a failure are left in `results`.
|
||||
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
|
||||
ConversionPatternRewriter& rewriter) {
|
||||
for (auto result : llvm::enumerate(op->getResults())) {
|
||||
// Only ranked tensor results are handled.
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) return failure();
|
||||
|
||||
// Static shapes go through InsertAlloc; nothing else is needed for them.
if (resultType.hasStaticShape()) {
|
||||
results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
continue;
|
||||
}
|
||||
// Dynamic shapes: the op itself must be able to reify its result shapes.
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
|
||||
// The shapes are reified afresh on every dynamically shaped result.
SmallVector<Value, 1> results_shape;
|
||||
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
// NOTE(review): indexing by result.index() assumes reifyReturnTypeShapes
// fills one entry per op result, in result order -- confirm.
results.push_back(
|
||||
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
|
||||
results_shape[result.index()], &rewriter));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename HloOpTy>
|
||||
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||
public:
|
||||
@ -95,29 +121,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||
HloOpTy hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
const auto& original_results = op->getResults();
|
||||
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
||||
for (auto result : llvm::enumerate(original_results)) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
}
|
||||
if (resultType.hasStaticShape()) {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto status =
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
rewriter.replaceOp(
|
||||
@ -139,28 +144,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
|
||||
mhlo::DotOp hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
const auto& original_results = op->getResults();
|
||||
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
||||
for (auto result : llvm::enumerate(original_results)) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
}
|
||||
if (resultType.hasStaticShape()) {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
if (failed(
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||
return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||
|
||||
// TODO(silvasean): Move this helper to MLIR core.
|
||||
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
|
||||
@ -194,8 +179,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
|
||||
Value transformed_operand =
|
||||
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||
rewriter.create<lmhlo::BroadcastInDimOp>(
|
||||
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
|
||||
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
|
||||
|
||||
rewriter.replaceOp(op, {resultBuffer});
|
||||
|
||||
@ -211,48 +195,76 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
auto loc = op.getLoc();
|
||||
auto operand_type = operand.getType().cast<MemRefType>();
|
||||
auto operand_shape = operand_type.getShape();
|
||||
auto operand_rank = operand_type.getRank();
|
||||
|
||||
SmallVector<Value, 2> sizes, strides;
|
||||
sizes.reserve(operand_shape.size());
|
||||
strides.reserve(operand_shape.size());
|
||||
auto result_type = op.getType().cast<RankedTensorType>();
|
||||
auto result_rank = result_type.getRank();
|
||||
|
||||
Value zero = b->create<ConstantIndexOp>(loc, 0);
|
||||
Value one = b->create<ConstantIndexOp>(loc, 1);
|
||||
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
|
||||
Value broadcast_dim_value =
|
||||
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
|
||||
Value result_dim_size = b->create<ExtractElementOp>(
|
||||
loc, op.output_dimensions(), broadcast_dim_value);
|
||||
Value operand_dim_size =
|
||||
ShapedType::isDynamic(operand_shape[dim.index()])
|
||||
? b->create<DimOp>(loc, operand, dim.index()).getResult()
|
||||
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
|
||||
.getResult();
|
||||
|
||||
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
|
||||
// tensor<index> for `output_dimensions` as well.
|
||||
// Compute a reversed scan product. Compute the stride for the dimensions so
|
||||
// far, working from minor to major dimensions. Additionally, save the
|
||||
// operand shape Values to use in the next loop.
|
||||
SmallVector<Value, 2> operand_strides(operand_rank, one);
|
||||
SmallVector<Value, 2> operand_sizes(operand_rank, one);
|
||||
Value stride_so_far = one;
|
||||
for (int i = operand_rank - 1; i >= 0; --i) {
|
||||
Value operand_dim_size =
|
||||
ShapedType::isDynamic(operand_shape[i])
|
||||
? b->create<DimOp>(loc, operand, i).getResult()
|
||||
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
|
||||
operand_sizes[i] = operand_dim_size;
|
||||
|
||||
operand_strides[i] = stride_so_far;
|
||||
if (i > 0) {
|
||||
stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Value, 2> sizes, strides;
|
||||
sizes.reserve(result_rank);
|
||||
strides.reserve(result_rank);
|
||||
|
||||
DenseMap<int, int> output_to_input_dim;
|
||||
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
|
||||
output_to_input_dim[dim.value().getSExtValue()] = dim.index();
|
||||
}
|
||||
for (int i = 0; i < result_rank; ++i) {
|
||||
Value i_val = b->create<ConstantIndexOp>(loc, i);
|
||||
Value result_dim_size =
|
||||
b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
|
||||
if (!result_dim_size.getType().isIndex()) {
|
||||
result_dim_size =
|
||||
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
|
||||
}
|
||||
sizes.push_back(result_dim_size);
|
||||
|
||||
auto it = output_to_input_dim.find(i);
|
||||
// If the rank of the output is greater than the rank of the input, i.e.
|
||||
// there was no output dimension in the inverse broadcast_dimensions map
|
||||
// we also set stride to 0 to emulate padding of the shape with 1s and the
|
||||
// corresponding expansion.
|
||||
if (it == output_to_input_dim.end()) {
|
||||
strides.push_back(zero);
|
||||
continue;
|
||||
}
|
||||
|
||||
// There can be two cases:
|
||||
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
|
||||
// 1) Operand dim == result dim => expansion is not needed
|
||||
//    => stride := flattened buffer stride
|
||||
// 2) Operand dim < result dim => expansion is needed => stride := 0.
|
||||
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
|
||||
operand_dim_size, result_dim_size);
|
||||
strides.push_back(
|
||||
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
|
||||
|
||||
// Size of input dim can be set to the size of the corresponding output
|
||||
// dimension for both cases.
|
||||
sizes.push_back(result_dim_size);
|
||||
int dim = it->second;
|
||||
Value is_expansion = b->create<CmpIOp>(
|
||||
loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
|
||||
strides.push_back(b->create<mlir::SelectOp>(loc, is_expansion, zero,
|
||||
operand_strides[dim]));
|
||||
}
|
||||
|
||||
// Type-erased memref type with static rank, dynamic sizes and strides.
|
||||
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
|
||||
SmallVector<int64_t, 2> dynamic_layout(result_rank,
|
||||
MemRefType::kDynamicStrideOrOffset);
|
||||
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
|
||||
SmallVector<int64_t, 2> dynamic_shape(result_rank,
|
||||
MemRefType::kDynamicSize);
|
||||
auto type_erased_memref_type = MemRefType::get(
|
||||
dynamic_shape, operand_type.getElementType(),
|
||||
|
@ -1,4 +1,6 @@
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
|
||||
// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
|
||||
// RUN: | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
@ -153,64 +155,41 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
func @external_func() -> tensor<3xi64>
|
||||
|
||||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
|
||||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
|
||||
|
||||
// CHECK-LABEL: func @dyn_broadcast
|
||||
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
|
||||
// CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
|
||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||
%c1 = constant 1 : i64
|
||||
%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
|
||||
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
|
||||
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
|
||||
// CHECK: %[[C1__:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
|
||||
// CHECK: %[[C0___:.*]] = constant 0 : index
|
||||
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
|
||||
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
|
||||
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
|
||||
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// CHECK: %[[C2_:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
|
||||
// CHECK: %[[C1___:.*]] = constant 1 : index
|
||||
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
|
||||
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
|
||||
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
|
||||
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
|
||||
// CHECK-SAME: offset: [0],
|
||||
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
|
||||
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
|
||||
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
|
||||
|
||||
// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
|
||||
|
||||
// Do not store the value back to avoid the tensor-store being rewritten to
|
||||
// a copy into the pre-allocated argument.
|
||||
return
|
||||
%rank = rank %tensor_result : tensor<?x?x?xf32>
|
||||
return %rank : index
|
||||
}
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
|
||||
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
|
||||
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
|
||||
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
|
||||
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
|
||||
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
|
||||
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
|
||||
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
|
||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
|
||||
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
|
||||
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
@ -483,11 +462,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
@ -508,11 +485,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
@ -645,7 +620,7 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
|
||||
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
|
||||
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
|
||||
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
|
||||
// CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
|
||||
// CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
|
||||
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
|
||||
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
|
||||
shape.assuming_yield %7 : tensor<?xf16>
|
||||
|
@ -390,6 +390,7 @@ cc_library(
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
"transforms/generated_prepare_tf.inc",
|
||||
"transforms/insert_call_once_op.cc",
|
||||
"transforms/legalize_tf.cc",
|
||||
"transforms/legalize_tf_while.cc",
|
||||
"transforms/lower_static_tensor_list.cc",
|
||||
|
@ -453,6 +453,11 @@ class Translator {
|
||||
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
// Build call once operator.
|
||||
BufferOffset<tflite::Operator> BuildCallOnceOperator(
|
||||
mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
// Builds custom operators.
|
||||
// Templated on a) data type of custom_option to be stored into flatbuffer,
|
||||
// and b) TFL custom op type.
|
||||
@ -787,6 +792,22 @@ BufferOffset<tflite::Operator> Translator::BuildIfOperator(
|
||||
builtin_options);
|
||||
}
|
||||
|
||||
// Serializes a `tfl.call_once` op as a CALL_ONCE builtin operator.
//
// `operands` and `results` are the tensor indices stored as the operator's
// inputs and outputs. The init subgraph is found by looking up the op's
// `session_init_function` name in `subgraph_index_map_`.
BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
    mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
    const std::vector<int32_t>& results) {
  const auto call_once_opcode =
      GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);

  // `at()` is used for the lookup, so the function name is expected to be
  // present in the map.
  const std::string init_function_name = op.session_init_function().str();
  const int init_subgraph = subgraph_index_map_.at(init_function_name);

  // The options table and the two index vectors are written to the builder in
  // this order: options first, then inputs, then outputs.
  const auto options_offset =
      tflite::CreateCallOnceOptions(builder_, init_subgraph).Union();
  const auto input_indices = builder_.CreateVector(operands);
  const auto output_indices = builder_.CreateVector(results);

  return tflite::CreateOperator(builder_, call_once_opcode, input_indices,
                                output_indices,
                                tflite::BuiltinOptions_CallOnceOptions,
                                options_offset);
}
|
||||
|
||||
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
||||
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
@ -1026,6 +1047,12 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
|
||||
if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
|
||||
return BuildCallOnceOperator(initOp, operands, results);
|
||||
}
|
||||
}
|
||||
|
||||
std::string op_name = inst->getName().getStringRef().str();
|
||||
uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);
|
||||
|
||||
|
@ -448,13 +448,54 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
|
||||
return op.getOperation();
|
||||
}
|
||||
|
||||
// Gets a constant splat for the given value of type. Requires value to be of
|
||||
// type static shaped RankedTensorType. `unique_index` is used to get the unique
|
||||
// value for the attribute.
|
||||
static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
|
||||
OpBuilder builder) {
|
||||
mlir::Type element_ty = getElementTypeOrSelf(type);
|
||||
|
||||
if (element_ty.isSignlessInteger())
|
||||
return DenseElementsAttr::get(
|
||||
type, builder.getIntegerAttr(element_ty, unique_index));
|
||||
|
||||
if (element_ty.isa<mlir::FloatType>())
|
||||
return DenseElementsAttr::get(
|
||||
type, builder.getFloatAttr(element_ty, unique_index));
|
||||
|
||||
if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
|
||||
mlir::RankedTensorType new_type =
|
||||
RankedTensorType::get(type.getShape(), qtype.getStorageType());
|
||||
return DenseElementsAttr::get(
|
||||
new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
|
||||
}
|
||||
llvm_unreachable("unhandled element type");
|
||||
}
|
||||
|
||||
// TODO(b/172664358): Creates a new op instead of reusing constant op.
|
||||
// Creates a constant op to represent stateful variable. The function static
|
||||
// variable `stateful_variable_idx` is used as a unique value for each constant
|
||||
// to avoid CSEed. `tensor` is the data structure of flatbuffer. `shaped_type`
|
||||
// is the ShapedType for the const op.
|
||||
Operation* BuildVariableOp(const tflite::TensorT& tensor,
|
||||
mlir::RankedTensorType shaped_type,
|
||||
OpBuilder builder, Location loc) {
|
||||
static int stateful_variable_idx = 0;
|
||||
mlir::ElementsAttr value =
|
||||
GetSplat(shaped_type, stateful_variable_idx++, builder);
|
||||
if (IsQuantized(tensor)) {
|
||||
auto op = builder.create<tfl::QConstOp>(
|
||||
loc, mlir::TypeAttr::get(shaped_type), value);
|
||||
return op.getOperation();
|
||||
}
|
||||
auto op = builder.create<tfl::ConstOp>(loc, value);
|
||||
return op.getOperation();
|
||||
}
|
||||
|
||||
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
const std::vector<uint8_t>& buffer,
|
||||
OpBuilder builder, Location loc) {
|
||||
if (buffer.empty()) {
|
||||
return errors::InvalidArgument("Constant's buffer may not be empty");
|
||||
}
|
||||
|
||||
bool is_variable, OpBuilder builder,
|
||||
Location loc) {
|
||||
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
@ -466,7 +507,9 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
auto elem_type = shaped_type.getElementType();
|
||||
|
||||
mlir::ElementsAttr value;
|
||||
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
|
||||
if (is_variable) {
|
||||
return BuildVariableOp(tensor, shaped_type, builder, loc);
|
||||
} else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
|
||||
TF_ASSIGN_OR_RETURN(value,
|
||||
ConvertFloatBuffer(shaped_type, float_type, buffer));
|
||||
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
|
||||
@ -846,19 +889,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
GetTensorIndices(subgraph, ordered_input_arrays));
|
||||
}
|
||||
|
||||
// Add state variables to inputs.
|
||||
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
|
||||
func_inputs.end());
|
||||
for (int i = 0, end = subgraph.tensors.size(); i < end; i++) {
|
||||
auto& tensor = *subgraph.tensors.at(i);
|
||||
if (tensor.is_variable && !input_index_set.contains(i)) {
|
||||
func_inputs.emplace_back(i);
|
||||
input_index_set.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto input_or_variable : func_inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input_or_variable);
|
||||
for (int input : func_inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input);
|
||||
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
|
||||
// but we cannot differentiate scalars from unranked tensors.
|
||||
// Here we reverse the default assumption that shape = [] means unranked.
|
||||
@ -889,7 +921,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
|
||||
for (auto output : func_outputs) {
|
||||
const bool is_func_input = input_index_set.contains(output);
|
||||
const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
|
||||
output) != func_inputs.end();
|
||||
bool is_constant = !is_op_output[output] && !is_func_input;
|
||||
// There are 2 cases tensor is scalar when it doesn't have a shape in
|
||||
// flatbuffer:
|
||||
@ -991,7 +1024,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
|
||||
op_builder, const_loc)
|
||||
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
const_tensor.is_variable, op_builder, const_loc);
|
||||
if (!op_or_err.ok()) {
|
||||
return emitError(const_loc, op_or_err.status().ToString()),
|
||||
op_or_err.status();
|
||||
@ -1051,7 +1084,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
|
||||
op_builder, const_loc)
|
||||
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
const_tensor.is_variable, op_builder, const_loc);
|
||||
if (!op_or_err.ok()) {
|
||||
return emitError(const_loc, op_or_err.status().ToString()),
|
||||
op_or_err.status();
|
||||
|
@ -1972,6 +1972,43 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||
return value();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Constant-folds `tfl.cast` when the single operand is a constant integer
// tensor and both the operand and result element types are integer types.
//
// Returns the folded elements attribute, or nullptr (no fold) when the operand
// is not a DenseIntElementsAttr or either element type is not an IntegerType.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1);
  // For now, only supports cast between integer types.
  auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
  if (!elements_attr) {
    return nullptr;
  }

  auto result_element_type =
      getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
  auto operand_element_type = input()
                                  .getType()
                                  .cast<ShapedType>()
                                  .getElementType()
                                  .dyn_cast<IntegerType>();
  // Returns nullptr if either result/operand element type is not integer.
  if (!result_element_type || !operand_element_type) {
    return nullptr;
  }

  const bool is_input_unsigned = operand_element_type.isUnsigned();
  // 1-bit integers are booleans; they need dedicated handling below.
  const bool involves_bool = operand_element_type.getWidth() == 1 ||
                             result_element_type.getWidth() == 1;
  const int output_bitwidth = result_element_type.getWidth();
  // The integer cast op is the same as a C integer cast. Depending on the
  // operand type's signedness, the value is either zero- or sign-extended
  // (or truncated) to the output width.
  auto cast = [&](const APInt& value) -> APInt {
    if (involves_bool) {
      // Extension/truncation is wrong when a boolean is involved: a signless
      // i1 `true` would sign-extend to -1 instead of 1, and truncating an
      // even non-zero integer to one bit would yield `false`. Map "non-zero"
      // to 1 and zero to 0 explicitly instead.
      return APInt(output_bitwidth, value != 0);
    }
    return is_input_unsigned ? value.zextOrTrunc(output_bitwidth)
                             : value.sextOrTrunc(output_bitwidth);
  };

  return elements_attr.mapValues(result_element_type, cast);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
|
||||
TFL_I32Tensor:$begin,
|
||||
TFL_I32Tensor:$end,
|
||||
TFL_I32Tensor:$strides,
|
||||
@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -3443,6 +3443,8 @@ def TFL_CastOp : TFL_Op<"cast", [
|
||||
// TFLite's cast op does not utilize CastOptions, instead derives types
|
||||
// from the TfLiteTensors.
|
||||
let hasOptions = 0;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
|
||||
@ -4358,6 +4360,21 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TFL_CallOnceOp : TFL_Op<"call_once", []> {
|
||||
let summary = "Invokes an initialization function";
|
||||
|
||||
let description = [{
|
||||
This operation invokes the given initialization function for the session
|
||||
initializer in tf saved model dialect.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$session_init_function
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TFL_CustomOp : Op<TFL_Dialect, "custom", [
|
||||
NoSideEffect, NoQuantizableResult]> {
|
||||
let summary = "Custom op";
|
||||
|
@ -52,6 +52,12 @@ struct QuantizationSpecs {
|
||||
// weight FakeQuant).
|
||||
bool disable_per_channel = false;
|
||||
|
||||
// When set to true, the fixed output ranges of the activation ops (tanh,
|
||||
// sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
|
||||
// emulation ops should be specified after the ops in the input graph. This
|
||||
// flag should be set to false for post-training quantization.
|
||||
bool disable_enforced_fixed_output_range = false;
|
||||
|
||||
// The node type when the model is exported. Currently this is limited to
|
||||
// DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
|
||||
// `weight_quantization` flag needs to set to true. When DT_QUINT8 is used,
|
||||
|
@ -587,3 +587,55 @@ func @rsqrt_bf16() -> tensor<bf16> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_i64_to_i32
|
||||
func @cast_i64_to_i32() -> tensor<5xi32> {
|
||||
%cst = constant dense<[-1, 0, 1, 2147483647, 2147483648]> : tensor<5xi64>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<5xi64>) -> tensor<5xi32>
|
||||
return %0 : tensor<5xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-1, 0, 1, 2147483647, -2147483648]> : tensor<5xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_i32_to_ui8
|
||||
func @cast_i32_to_ui8() -> tensor<6xui8> {
|
||||
%cst = constant dense<[0, -1, 256, 127, -128, -129]> : tensor<6xi32>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<6xi32>) -> tensor<6xui8>
|
||||
return %0 : tensor<6xui8>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, 255, 0, 127, 128, 127]> : tensor<6xui8>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_ui8_to_i8
|
||||
func @cast_ui8_to_i8() -> tensor<4xi8> {
|
||||
%cst = constant dense<[0, 255, 127, 128]> : tensor<4xui8>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi8>
|
||||
return %0 : tensor<4xi8>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, -1, 127, -128]> : tensor<4xi8>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_i8_to_i32
|
||||
func @cast_i8_to_i32() -> tensor<4xi32> {
|
||||
%cst = constant dense<[0, 128, -1, -128]> : tensor<4xi8>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<4xi8>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, -128, -1, -128]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cast_ui8_to_i32
|
||||
func @cast_ui8_to_i32() -> tensor<4xi32> {
|
||||
%cst = constant dense<[0, 128, 129, 255]> : tensor<4xui8>
|
||||
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<[0, 128, 129, 255]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
|
||||
|
@ -411,11 +411,11 @@ versions {
|
||||
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @cond_false_10{{.+}}is_stateless = true{{.+}}then_branch = @cond_true_10
|
||||
# CHECK: "tf.If"{{.+}}else_branch = @cond_false0{{.+}}is_stateless = false{{.+}}then_branch = @cond_true0
|
||||
# CHECK: func @cond_false_10
|
||||
# CHECK: func private @cond_false_10
|
||||
# CHECK-NEXT: tfl.div
|
||||
# CHECK: func @cond_true_10
|
||||
# CHECK: func private @cond_true_10
|
||||
# CHECK-NEXT: tfl.sub
|
||||
# CHECK: func @cond_false0
|
||||
# CHECK: func private @cond_false0
|
||||
# CHECK-NEXT: tfl.mul
|
||||
# CHECK: func @cond_true0
|
||||
# CHECK: func private @cond_true0
|
||||
# CHECK-NEXT: tfl.add
|
||||
|
@ -78,14 +78,14 @@ versions {
|
||||
}
|
||||
|
||||
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
|
||||
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_5:.*]] = constant unit
|
||||
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
# CHECK-DAG: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
|
||||
# CHECK-DAG: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
|
||||
# CHECK-DAG: %[[VAL_5:.*]] = constant unit
|
||||
# CHECK-DAG: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
|
||||
# CHECK-DAG: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
|
||||
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
|
||||
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
|
||||
|
@ -8,9 +8,11 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
|
||||
return %24 : tensor<1x4xf32>
|
||||
// CHECK-LABEL: main
|
||||
// separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
|
||||
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
|
||||
// CHECK: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
|
||||
// CHECK: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
|
||||
// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ( {
|
||||
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
// CHECK: return %[[RES0]]
|
||||
// CHECK: return %[[RES2]]
|
||||
|
||||
}
|
||||
|
||||
|
@ -5,8 +5,8 @@
|
||||
// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
|
||||
// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
|
||||
// CHECK: return %0 : tensor<*xf32>
|
||||
// CHECK: func @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s | FileCheck %s
|
||||
|
||||
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
@ -1026,11 +1026,11 @@ func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810(%arg0: t
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
|
||||
// CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
|
||||
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
|
||||
|
||||
func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
%1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
@ -2160,11 +2160,11 @@ func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_As
|
||||
|
||||
|
||||
|
||||
// CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
|
||||
// CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<?x1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>)
|
||||
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>
|
||||
|
||||
func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
@ -3190,7 +3190,7 @@ func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
|
||||
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
|
||||
// CHECK: return %0 : tensor<?x!tf.string>
|
||||
|
||||
@ -3213,7 +3213,7 @@ func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) ->
|
||||
// CHECK: return %0 : tensor<?x!tf.string>
|
||||
// CHECK: }
|
||||
|
||||
func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
|
||||
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
@ -3330,12 +3330,12 @@ func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name
|
||||
%71 = "tf.Identity"(%70) {device = ""} : (tensor<3xi64>) -> tensor<3xi64>
|
||||
return %68, %71, %64 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3345,12 +3345,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
|
||||
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %5 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3359,12 +3359,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
|
||||
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %4 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3374,12 +3374,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_Assert
|
||||
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %5 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3389,12 +3389,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
|
||||
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %5 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3403,12 +3403,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
|
||||
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %4 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3418,12 +3418,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_Asse
|
||||
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %5 : tensor<i1>
|
||||
}
|
||||
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
|
||||
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
|
||||
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %1 : tensor<i1>
|
||||
}
|
||||
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<"Inputs must have identical ragged splits"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
%2 = "tf.Const"() {value = dense<"x (NGrams/SlidingWindow/RaggedGetItem/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
@ -3433,12 +3433,12 @@ func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_
|
||||
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %5 : tensor<i1>
|
||||
}
|
||||
// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
|
||||
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
|
||||
|
||||
|
||||
func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
%0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64>
|
||||
%1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64>
|
||||
%2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64>
|
||||
@ -3448,6 +3448,6 @@ func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "va
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
// CHECK: func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
|
||||
// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "tftext:custom:SgnnProjection", custom_option = opaque<"tfl", "0x686173685F736565640000000A00000071F86A71318B0AA8023F331CD59AC14AC5E7E95CDE35AD68F474A4711A3C5CC2421F5B20AE52EB1F6275636B6574730002094200030000000100000002000000FFFFFF7F44000000062E0A2601"> : tensor<93xi8>} : (tensor<?x!tf.string>, tensor<?xi64>) -> tensor<?x10xf64>
|
||||
// CHECK: return %0 : tensor<?x10xf64>
|
||||
|
40
tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir
Normal file
40
tensorflow/compiler/mlir/lite/tests/insert_call_once_op.mlir
Normal file
@ -0,0 +1,40 @@
|
||||
// RUN: tf-opt -split-input-file -tfl-insert-call-once-op %s | FileCheck %s
|
||||
|
||||
// Tests that new call_once op is added when there is a session initializer.
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
"tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> ()
|
||||
|
||||
func @init_all_tables()
|
||||
attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
|
||||
%cst = constant dense<[1, 2, 3, 4]> : tensor<4xi64>
|
||||
%cst_0 = constant dense<["a", "b", "c", "d"]> : tensor<4x!tf.string>
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
|
||||
"tf.LookupTableImportV2"(%0, %cst, %cst_0) {device = ""} : (tensor<!tf.resource>, tensor<4xi64>, tensor<4x!tf.string>) -> ()
|
||||
return
|
||||
// CHECK-LABEL: @init_all_tables
|
||||
}
|
||||
|
||||
func @serving_default(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<*x!tf.string> {tf_saved_model.index_path = ["r"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "hash_table_Lookup/LookupTableFindV2:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%cst = constant dense<"f"> : tensor<!tf.string>
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
|
||||
%1 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor<!tf.resource>, tensor<i64>, tensor<!tf.string>) -> tensor<*x!tf.string>
|
||||
return %1 : tensor<*x!tf.string>
|
||||
// CHECK-LABEL: @serving_default
|
||||
// CHECK: "tfl.call_once"() {session_init_function = "init_all_tables"} : () -> ()
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that no call_once op is added.
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
func @no_call_once(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<i64> {tf_saved_model.index_path = ["r"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
return %arg0 : tensor<i64>
|
||||
// CHECK-LABEL: no_call_once
|
||||
// CHECK-NOT: "tfl.call_once"
|
||||
}
|
||||
}
|
@ -1122,6 +1122,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
|
||||
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
return %0 : tensor<1x2x2x5x!tf.string>
|
||||
// CHECK-LABEL: strided_slice_with_string
|
||||
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
}
|
||||
|
||||
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
|
||||
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
|
||||
return %0 : tensor<?x3x5xf32>
|
||||
|
@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>,
|
||||
return %0 : tensor<1x2x2x5x!tf.quint8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testStridedSliceWithString
|
||||
func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
|
||||
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
return %0 : tensor<1x2x2x5x!tf.string>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {
|
||||
|
@ -407,16 +407,16 @@ func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @notFuseMulIntoDepthwiseConv2d
|
||||
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x4x4x2xf32>) -> tensor<1x4x4x2xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
|
||||
%cst1 = constant dense<2.0> : tensor<2xf32>
|
||||
%cst2 = constant dense<3.0> : tensor<112x2xf32>
|
||||
%cst2 = constant dense<[[3.1, 3.2], [3.1, 3.2], [3.1, 3.2], [3.1, 3.2]]> : tensor<4x2xf32>
|
||||
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32>
|
||||
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0.
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x4x4x2xf32>, tensor<4x2xf32>) -> tensor<1x4x4x2xf32>
|
||||
|
||||
return %1 : tensor<1x112x112x2xf32>
|
||||
return %1 : tensor<1x4x4x2xf32>
|
||||
|
||||
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0)
|
||||
// CHECK: %1 = "tfl.mul"(%0, %cst_1)
|
||||
@ -484,17 +484,17 @@ func @FuseFullyConnectedAddWithScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs
|
||||
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: tensor<4x37xf32>) -> tensor<4x4xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<40x40xf32>
|
||||
%cst2 = constant dense<[[2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3]]> : tensor<4x4xf32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x37xf32>, tensor<4x37xf32>, none) -> (tensor<4x4xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
|
||||
return %1 : tensor<40x40xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
|
||||
// CHECK: %[[unit:.*]] = constant unit
|
||||
// CHECK: %[[filter:.*]] = constant dense<2.000000e+00> : tensor<40x40xf32>
|
||||
// CHECK: %[[filter:.*]] = constant dense<{{.*}}> : tensor<4x4xf32>
|
||||
// CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[unit]])
|
||||
// CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]]
|
||||
// CHECK: return %[[add_result]]
|
||||
@ -851,17 +851,17 @@ func @fuseDivIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseMulIntoConv2d_Scalar
|
||||
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x1xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]> : tensor<1x2x2x2xf32>
|
||||
%cst1 = constant dense<1.0> : tensor<2xf32>
|
||||
%cst1 = constant dense<1.0> : tensor<1xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<f32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<f32>) -> tensor<1x112x112x2xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x1xf32>, tensor<f32>) -> tensor<1x112x112x1xf32>
|
||||
|
||||
return %1 : tensor<1x112x112x2xf32>
|
||||
return %1 : tensor<1x112x112x1xf32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00], [6.000000e+00, 8.000000e+00]], {{\[\[}}1.000000e+01, 1.200000e+01], [1.400000e+01, 1.600000e+01]]]]> : tensor<1x2x2x2xf32>
|
||||
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
|
||||
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1xf32>
|
||||
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
@ -1397,3 +1397,33 @@ func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x
|
||||
// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddWithSplat2D
|
||||
func @FuseFullyConnectedAddWithSplat2D(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<40x40xf32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
|
||||
|
||||
return %1 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[BIAS:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[BIAS]])
|
||||
// CHECK: return %[[FC_RESULT]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseMulIntoConv2d_Splat2D
|
||||
func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
|
||||
%cst0 = constant dense<[[[[1.0, 2.0]]], [[[3.0, 4.0]]]]> : tensor<2x1x1x2xf32>
|
||||
%cst1 = constant dense<1.0> : tensor<2xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<1x112x112x2xf32>
|
||||
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32>
|
||||
|
||||
return %1 : tensor<1x112x112x2xf32>
|
||||
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00]]], {{\[\[\[}}6.000000e+00, 8.000000e+00]]]]> : tensor<2x1x1x2xf32>
|
||||
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
|
||||
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
@ -77,3 +77,32 @@ func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tens
|
||||
// CHECK-NEXT: return %[[softmax]], %[[argmax]] : tensor<128x16xf32>, tensor<128xi32>
|
||||
return %2, %3 : tensor<128x16xf32>, tensor<128xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: PruneUnusedLstm
|
||||
func @PruneUnusedLstm(%arg0: tensor<1x28x28xf32>) -> (tensor<1x28x28xf32>) {
|
||||
%input = "tfl.quantize"(%arg0) {qtype = tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>} : (tensor<1x28x28xf32>) -> tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>
|
||||
%cst_1 = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi8>} : () -> tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cst_2 = constant unit
|
||||
%cst_3 = "tfl.pseudo_qconst"() {qtype = tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x20xi8>} : () -> tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cst_7 = "tfl.pseudo_qconst"() {qtype = tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20xi8>} : () -> tensor<20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cst_11 = "tfl.pseudo_qconst"() {qtype = tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x28xi8>} : () -> tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
%cell_input = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi6>} : () -> tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%input,
|
||||
%cst_11, %cst_11, %cst_11, %cst_11,
|
||||
%cst_3, %cst_3, %cst_3, %cst_3,
|
||||
%cst_2, %cst_2, %cst_2,
|
||||
%cst_7, %cst_7, %cst_7, %cst_7,
|
||||
%cst_2, %cst_2,
|
||||
%cst_1, %cell_input,
|
||||
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
|
||||
: ( tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>,
|
||||
tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>,
|
||||
tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>,
|
||||
none, none, none,
|
||||
tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>,
|
||||
none, none,
|
||||
tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>,
|
||||
none, none, none, none) -> tensor<1x28x20x!quant.uniform<i8:f32, 0.006:-34>>
|
||||
return %arg0 : tensor<1x28x28xf32>
|
||||
// CHECK-NEXT: return %arg0
|
||||
}
|
||||
|
@ -166,3 +166,37 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>)
|
||||
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
|
||||
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeLstmCellInput
|
||||
func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
|
||||
%cst_1 = constant dense<1.0> : tensor<1x20xf32>
|
||||
%cst_2 = constant unit
|
||||
%cst_3 = constant dense<1.0> : tensor<20x20xf32>
|
||||
%cst_7 = constant dense<1.0> : tensor<20xf32>
|
||||
%cst_11 = constant dense<1.0> : tensor<20x28xf32>
|
||||
%cell_input = constant dense<0.0> : tensor<1x20xf32>
|
||||
%cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
|
||||
%cst_11, %cst_11, %cst_11, %cst_11,
|
||||
%cst_3, %cst_3, %cst_3, %cst_3,
|
||||
%cst_2, %cst_2, %cst_2,
|
||||
%cst_7, %cst_7, %cst_7, %cst_7,
|
||||
%cst_2, %cst_2,
|
||||
%cst_1, %cell_stats,
|
||||
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
|
||||
: ( tensor<1x28x28xf32>,
|
||||
tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
|
||||
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
|
||||
none, none, none,
|
||||
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
|
||||
none, none,
|
||||
tensor<1x20xf32>, tensor<1x20xf32>,
|
||||
none, none, none, none) -> tensor<1x28x20xf32>
|
||||
return %0 : tensor<1x28x20xf32>
|
||||
// CHECK: %[[none:.*]] = constant unit
|
||||
// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
|
||||
// Checks if input 19 is correctly passed from a dequantize op.
|
||||
// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
|
||||
}
|
||||
|
@ -30,9 +30,9 @@ func @while() -> tensor<1xf32>
|
||||
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
|
||||
return %0#1 : tensor<1xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @WhileOp_cond(
|
||||
// CHECK-LABEL: func private @WhileOp_cond(
|
||||
// CHECK: tfl.greater
|
||||
// CHECK-LABEL: func @WhileOp_body(
|
||||
// CHECK-LABEL: func private @WhileOp_body(
|
||||
// CHECK: tfl.sub
|
||||
// CHECK: tfl.add
|
||||
|
||||
@ -63,21 +63,21 @@ func @while2(%cst : tensor<i32>) -> tensor<1xf32> attributes {tf.entry_function
|
||||
return %0#1 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
func private @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) attributes {sym_visibility = "private"} {
|
||||
func private @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) {
|
||||
%0 = "tfl.sub"(%arg0, %arg2) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
return %0, %1, %arg2 : tensor<*xi32>, tensor<*xf32>, tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @WhileOp_cond(
|
||||
// CHECK-LABEL: func private @WhileOp_cond(
|
||||
// CHECK: tfl.greater
|
||||
// CHECK-LABEL: func @WhileOp_body(
|
||||
// CHECK-LABEL: func private @WhileOp_body(
|
||||
// CHECK: tfl.sub
|
||||
// CHECK: tfl.add
|
||||
|
||||
@ -152,14 +152,14 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
|
||||
// CHECK: tfl.yield
|
||||
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @tfl.while_cond(
|
||||
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
// CHECK-LABEL: func private @tfl.while_cond(
|
||||
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> {
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<i1>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @tfl.while_body(
|
||||
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) attributes {sym_visibility = "private"} {
|
||||
// CHECK-LABEL: func private @tfl.while_body(
|
||||
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) {
|
||||
// CHECK: [[VAL_91:%.*]] = "tfl.cast"
|
||||
// CHECK: return
|
||||
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>
|
||||
|
@ -234,6 +234,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// tf.variable to model this.
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(
|
||||
mlir::TFL::CreateSplitMergedOperandsPass());
|
||||
|
||||
// Add CallOnceOp when there is a session initializer function in tf saved
|
||||
// model dialect.
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,78 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
// This pass inserts a TFL::CallOnce op when tf_saved_model's session
|
||||
// initializer is given.
|
||||
class InsertCallOnceOpFromSessionInitializerPass
|
||||
: public mlir::PassWrapper<InsertCallOnceOpFromSessionInitializerPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
private:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void InsertCallOnceOpFromSessionInitializerPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
tf_saved_model::SessionInitializerOp session_init_op =
|
||||
tf_saved_model::GetSessionInitializerOp(module);
|
||||
|
||||
if (!session_init_op) return;
|
||||
|
||||
SymbolTable symbol_table(module);
|
||||
|
||||
for (auto sym_ref : session_init_op.initializers()) {
|
||||
FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
|
||||
sym_ref.cast<FlatSymbolRefAttr>().getValue());
|
||||
|
||||
if (!init_func_op) {
|
||||
module.emitError("no session initializer function found");
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
auto dict_attr =
|
||||
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
|
||||
if (!dict_attr) continue;
|
||||
|
||||
OpBuilder builder(func.getContext());
|
||||
builder.setInsertionPointToStart(&func.getBlocks().front());
|
||||
builder.create<TFL::CallOnceOp>(func.getLoc(), init_func_op.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Inserts a TFL::CallOnce op when tf_saved_model's session initializer is
|
||||
// given.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateInsertCallOnceOpFromSessionInitializerPass() {
|
||||
return std::make_unique<InsertCallOnceOpFromSessionInitializerPass>();
|
||||
}
|
||||
|
||||
// Statically registers InsertCallOnceOpFromSessionInitializerPass with the
// argument string "tfl-insert-call-once-op" and the description below.
static PassRegistration<InsertCallOnceOpFromSessionInitializerPass> pass(
|
||||
"tfl-insert-call-once-op",
|
||||
"Insert CallOnce op when tf_saved_model's session initializer is given");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
@ -27,11 +27,14 @@ limitations under the License.
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
@ -729,6 +732,143 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
}
|
||||
};
|
||||
|
||||
// If the operand to a broadcastable op is a splat constant, try to replace it
|
||||
// with a 0-d constant, e.g. before this optimization,
|
||||
// %cst = constant dense<1.0> : tensor<16x16x4xf32>
|
||||
// %0 = "tfl.conv_2d"...
|
||||
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
|
||||
// After this optimization:
|
||||
// %cst = constant dense<1.0> : tensor<f32>
|
||||
// %0 = "tfl.conv_2d"...
|
||||
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
|
||||
// This pattern can enable more fusing opportunities when the binary op is
|
||||
// following conv ops.
|
||||
template <typename BinaryOpType>
|
||||
struct ScalarizeSplatConstantForBroadcastableOps
|
||||
: public OpRewritePattern<BinaryOpType> {
|
||||
using OpRewritePattern<BinaryOpType>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(BinaryOpType binary_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
DenseElementsAttr splat_elements_attr;
|
||||
if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
constexpr int kSplatOperandIndex = 1;
|
||||
auto result_type =
|
||||
binary_op.getResult().getType().template cast<ShapedType>();
|
||||
mlir::Value non_splat_operand =
|
||||
binary_op.getOperand(1 - kSplatOperandIndex);
|
||||
auto non_splat_operand_type =
|
||||
non_splat_operand.getType().cast<ShapedType>();
|
||||
// If the other operand's shape does not equal to the result shape, then we
|
||||
// cannot scalarize the splat constant because the result shape relies on
|
||||
// the splat constant op's shape for broadcasting.
|
||||
if (!non_splat_operand_type.hasStaticShape() ||
|
||||
non_splat_operand_type.getShape() != result_type.getShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// If non-splat operand is not fusable affine ops, then no need to apply
|
||||
// this transformation.
|
||||
if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Creates a new scalar constant op using the splat value.
|
||||
mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
|
||||
auto scalar_elements_attr = DenseElementsAttr::get(
|
||||
RankedTensorType::get({},
|
||||
splat_elements_attr.getType().getElementType()),
|
||||
splat_elements_attr.getSplatValue());
|
||||
|
||||
auto scalar_constant_op = rewriter.create<ConstantOp>(
|
||||
splat_operand.getLoc(), scalar_elements_attr.getType(),
|
||||
scalar_elements_attr);
|
||||
|
||||
binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns true if this value is a splat constant op which can be scalarized.
|
||||
// Also returns the elements attr if this value is indeed a splat constant.
|
||||
bool IsScalarizableSplatConstant(mlir::Value value,
|
||||
DenseElementsAttr *elements_attr) const {
|
||||
if (!matchPattern(value, m_Constant(elements_attr))) {
|
||||
return false;
|
||||
}
|
||||
auto element_type = value.getType().cast<ShapedType>().getElementType();
|
||||
// Ignore per-axis quantized constants because after converting to scalar,
|
||||
// we will lose per-axis qantization parameter.
|
||||
if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
|
||||
return false;
|
||||
}
|
||||
if (IsScalar(value)) {
|
||||
return false;
|
||||
}
|
||||
return elements_attr->isSplat();
|
||||
}
|
||||
|
||||
// If this type is a scalar shaped type.
|
||||
bool IsScalar(mlir::Value value) const {
|
||||
auto type = value.getType().dyn_cast<ShapedType>();
|
||||
if (!type) {
|
||||
return false;
|
||||
}
|
||||
if (!type.hasStaticShape()) {
|
||||
return false;
|
||||
}
|
||||
return type.getNumElements() == 1;
|
||||
}
|
||||
|
||||
// Returns true if we can fuse an affine op with consuming binary op.
|
||||
bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
|
||||
if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
|
||||
TFL::FullyConnectedOp>(affine_op)) {
|
||||
return false;
|
||||
}
|
||||
DenseElementsAttr value;
|
||||
// Check that bias are constants if not none.
|
||||
Value bias = affine_op->getOperand(2);
|
||||
if (!bias.getType().isa<NoneType>() &&
|
||||
!matchPattern(bias, m_Constant(&value))) {
|
||||
return false;
|
||||
}
|
||||
// If the binary op is mul/div, also check that filter is constant.
|
||||
if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
|
||||
!matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// We can only fuse F32/BF16.
|
||||
auto is_fusable_type = [](Type t) {
|
||||
Type element_type = t;
|
||||
if (auto shaped_type = t.dyn_cast<ShapedType>()) {
|
||||
element_type = shaped_type.getElementType();
|
||||
}
|
||||
return element_type.isBF16() || element_type.isF32();
|
||||
};
|
||||
for (Type t : binary_op->getOperandTypes()) {
|
||||
if (!is_fusable_type(t)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Instantiations of ScalarizeSplatConstantForBroadcastableOps for the TFL
// sub, add, mul and div ops.
using ScalarizeSplatConstantForSub =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
|
||||
using ScalarizeSplatConstantForAdd =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
|
||||
using ScalarizeSplatConstantForMul =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
|
||||
using ScalarizeSplatConstantForDiv =
|
||||
ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
|
||||
|
||||
struct ConvertTrivialTransposeOpToReshapeOp
|
||||
: public OpRewritePattern<TFL::TransposeOp> {
|
||||
using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
|
||||
@ -818,6 +958,8 @@ void Optimize::runOnFunction() {
|
||||
OwningRewritePatternList phase_2_patterns;
|
||||
TFL::populateWithGenerated(ctx, phase_2_patterns);
|
||||
phase_2_patterns.insert<
|
||||
ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
|
||||
ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
|
||||
FuseFullyConnectedAndAdd, FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
|
||||
FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
|
||||
FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
|
||||
|
@ -94,6 +94,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
||||
// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseCustomOpsPass();
|
||||
|
||||
// Inserts a TFL::CallOnce op when the tf_saved_model's session initializer is
|
||||
// given.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateInsertCallOnceOpFromSessionInitializerPass();
|
||||
} // namespace TFL
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -139,6 +139,30 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Removes LSTMs that have dangling output.
|
||||
// LSTMs are not removed automatically becuase they are stateful ops.
|
||||
template <typename LstmOpTy>
|
||||
struct PruneUnusedLstm : public OpRewritePattern<LstmOpTy> {
|
||||
public:
|
||||
explicit PruneUnusedLstm(MLIRContext* context)
|
||||
: OpRewritePattern<LstmOpTy>(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(LstmOpTy lstm_op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Operation* op = lstm_op.getOperation();
|
||||
if (op->isKnownTerminator()) {
|
||||
return failure();
|
||||
}
|
||||
for (auto result : op->getOpResults()) {
|
||||
if (!result.use_empty()) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
|
||||
|
||||
void PostQuantizePass::runOnFunction() {
|
||||
@ -147,6 +171,7 @@ void PostQuantizePass::runOnFunction() {
|
||||
auto* ctx = func.getContext();
|
||||
TFL::populateWithGenerated(ctx, patterns);
|
||||
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
|
||||
patterns.insert<PruneUnusedLstm<TFL::UnidirectionalSequenceLSTMOp>>(ctx);
|
||||
applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
|
||||
if (!emit_quant_adaptor_ops_) {
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass applies quantization propagation on TFLite dialect.
|
||||
#include <cmath>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
|
||||
@ -21,10 +22,13 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
|
||||
@ -305,6 +309,52 @@ bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
|
||||
using PrepareQuantStats =
|
||||
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||
|
||||
// Returns the smallest power of two that is greater than or equal to `value`.
// The exponent is obtained by rounding log2(value) up to the next integer.
double power_of_two_bound(double value) {
  const double exponent = std::ceil(std::log2(value));
  return std::pow(2, exponent);
}
|
||||
|
||||
// Quantize recurrent input of LSTM with 16 bits.
|
||||
template <typename SourceOp, typename Q, typename DQ>
|
||||
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
|
||||
public:
|
||||
explicit ConvertLstmStatsToQDQs(MLIRContext* context)
|
||||
: OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
|
||||
LogicalResult matchAndRewrite(SourceOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
|
||||
op.input_cell_state().getDefiningOp());
|
||||
// Recurrent input is be used within an LSTM, and thus should have one use.
|
||||
if (!stats_op || !stats_op.getResult().hasOneUse()) {
|
||||
return failure();
|
||||
}
|
||||
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
|
||||
if (!stats) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
double max = std::max(
|
||||
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
|
||||
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
|
||||
double bound = power_of_two_bound(max);
|
||||
Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
|
||||
// maximum value is adjusted to get a scale of power_of_two(max)/32768.
|
||||
quant::QuantizedType quant_type = quant::fakeQuantAttrsToType(
|
||||
stats_op.getLoc(), 16, -bound, bound * 32767.0 / 32768.0,
|
||||
/*narrow_range*/ false, expressed, /*is_signed*/ true);
|
||||
|
||||
rewriter.setInsertionPointAfter(stats_op);
|
||||
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
|
||||
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
|
||||
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
using PrepareLstmQuantStats =
|
||||
ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
|
||||
quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||
|
||||
void PrepareQuantizePass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
MLIRContext* ctx = func.getContext();
|
||||
@ -326,7 +376,14 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
bool is_signed = quant_specs_.IsSignedInferenceType();
|
||||
int bit_width = quant_specs_.GetQuantizationTypeWidth();
|
||||
bool enforce_fixed_output_range = ContainsQuantizeOps(func);
|
||||
bool quantization_aware_training_mode = ContainsQuantizeOps(func);
|
||||
// Enforce fixed output range for post-training quantization and
|
||||
// when the model has quantization emulation ops, unless it was disabled
|
||||
// explicitly by the flag.
|
||||
bool enforced_output_range =
|
||||
(quant_specs_.post_training_quantization ||
|
||||
quantization_aware_training_mode) &&
|
||||
!quant_specs_.disable_enforced_fixed_output_range;
|
||||
if (is_signed) {
|
||||
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
|
||||
// Convert quant stats to int8 quantization parameters.
|
||||
@ -337,6 +394,7 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
|
||||
}
|
||||
patterns.insert<PrepareLstmQuantStats>(ctx);
|
||||
applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
|
||||
SanityCheckAndAdjustment(func);
|
||||
@ -345,8 +403,7 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// values (tensors).
|
||||
ApplyQuantizationParamsPropagation(
|
||||
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
|
||||
GetOpQuantSpec,
|
||||
enforce_fixed_output_range || quant_specs_.post_training_quantization);
|
||||
GetOpQuantSpec, enforced_output_range);
|
||||
|
||||
ConvertMlirQuantOpsToTFLQuantOps(func);
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
|
||||
}
|
||||
|
||||
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
|
||||
absl::optional<ConfigProto> config_proto) {
|
||||
const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto) {
|
||||
switch (GetUserRequest(config_proto)) {
|
||||
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
|
||||
return MlirBridgeRolloutPolicy::kEnabledByUser;
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -46,6 +47,7 @@ enum class MlirBridgeRolloutPolicy {
|
||||
// The config_proto param is a required input for all TF1 graphs but it is
|
||||
// redundant for TF2 graphs.
|
||||
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
|
||||
const tensorflow::Graph& graph,
|
||||
absl::optional<tensorflow::ConfigProto> config_proto);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
@ -32,10 +33,19 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
auto* shadow_run_success = monitoring::Counter<0>::New(
|
||||
"/tensorflow/mlir/shadow_run_success", "Success count of MLIR shadow runs");
|
||||
|
||||
auto* shadow_run_failure = monitoring::Counter<2>::New(
|
||||
"/tensorflow/mlir/shadow_run_failure", "Failure count of MLIR shadow runs",
|
||||
"kind", "name");
|
||||
|
||||
// Builds an absl::string_view from the data pointer and size of `ref`.
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
  return absl::string_view(ref.data(), ref.size());
}
|
||||
@ -109,7 +119,7 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
// Skip conversion from Graph to MLIR if none of the passes are enabled.
|
||||
const bool is_enabled =
|
||||
llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
|
||||
return pass_registration.pass->IsEnabled(config_proto);
|
||||
return pass_registration.pass->IsEnabled(config_proto, **graph);
|
||||
});
|
||||
|
||||
if (!is_enabled) {
|
||||
@ -123,6 +133,17 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
<< "(registered " << registry_->passes().size()
|
||||
<< " passes)";
|
||||
|
||||
// For scenarios when the new bridge is enabled by analysis we need to make
|
||||
// sure that MLIR transformations are executed in a shadow mode.
|
||||
// In this case, no changes should be done to the original `graph`
|
||||
// and no failures propagated to the user.
|
||||
bool enabled_by_analysis =
|
||||
mlir_rollout_policy_(**graph, config_proto) ==
|
||||
MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis;
|
||||
if (enabled_by_analysis) {
|
||||
LOG_FIRST_N(INFO, 1) << "Shadow run of MLIR enabled after graph analysis";
|
||||
}
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
mlir::MLIRContext context;
|
||||
RegisterDialects(context.getDialectRegistry());
|
||||
@ -130,10 +151,21 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
import_config.graph_as_function = true;
|
||||
import_config.control_outputs = *control_ret_node_names;
|
||||
import_config.upgrade_legacy = true;
|
||||
TF_ASSIGN_OR_RETURN(auto module_ref,
|
||||
ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||
import_config, &context));
|
||||
|
||||
auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||
import_config, &context);
|
||||
if (!module_ref_status.ok()) {
|
||||
if (enabled_by_analysis) {
|
||||
shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);
|
||||
|
||||
// Do not fail; let the old bridge run on the original `graph`.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return module_ref_status.status();
|
||||
}
|
||||
|
||||
auto module_ref = std::move(module_ref_status.ValueOrDie());
|
||||
AddDevicesToOp(*module_ref, &device_set);
|
||||
|
||||
for (auto& pass_registration : registry_->passes()) {
|
||||
@ -144,7 +176,17 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref));
|
||||
auto pass_status =
|
||||
pass_registration.pass->Run(config_proto, *module_ref, **graph);
|
||||
if (!pass_status.ok()) {
|
||||
if (enabled_by_analysis) {
|
||||
shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
|
||||
// Do not fail; let the old bridge run on the original `graph`.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return pass_status;
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
|
||||
@ -153,6 +195,25 @@ Status MlirFunctionOptimizationPass::Run(
|
||||
|
||||
GraphExportConfig export_config;
|
||||
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||
|
||||
// In case MLIR is enabled by analysis, verify that MLIR could be converted
|
||||
// back to TF graph. Original `graph` must stay the same.
|
||||
if (enabled_by_analysis) {
|
||||
auto empty_graph = std::make_unique<Graph>(OpRegistry::Global());
|
||||
FunctionLibraryDefinition empty_flib = empty_graph->flib_def();
|
||||
|
||||
auto mlir_to_graph_status =
|
||||
ConvertMlirToGraph(*module_ref, export_config, &empty_graph,
|
||||
&empty_flib, &control_ret_nodes);
|
||||
if (mlir_to_graph_status.ok()) {
|
||||
shadow_run_success->GetCell()->IncrementBy(1);
|
||||
} else {
|
||||
shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
|
||||
&control_ret_nodes),
|
||||
@ -183,7 +244,7 @@ Status MlirV1CompatGraphOptimizationPass::Run(
|
||||
const bool is_enabled =
|
||||
absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
|
||||
return pass_registration.pass->IsEnabled(
|
||||
options.session_options->config);
|
||||
options.session_options->config, **options.graph);
|
||||
});
|
||||
|
||||
if (!is_enabled) {
|
||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
@ -34,10 +37,11 @@ class MlirOptimizationPass {
|
||||
public:
|
||||
virtual ~MlirOptimizationPass() = default;
|
||||
virtual llvm::StringRef name() const = 0;
|
||||
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
|
||||
virtual bool IsEnabled(const ConfigProto& config_proto,
|
||||
const Graph& graph) const = 0;
|
||||
|
||||
virtual Status Run(const ConfigProto& config_proto,
|
||||
mlir::ModuleOp module) = 0;
|
||||
virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
const Graph& graph) = 0;
|
||||
};
|
||||
|
||||
class MlirOptimizationPassRegistry {
|
||||
@ -59,10 +63,14 @@ class MlirOptimizationPassRegistry {
|
||||
// Returns the global registry of MLIR optimization passes.
|
||||
static MlirOptimizationPassRegistry& Global();
|
||||
|
||||
// Register optimization `pass` with the given `priority`.
|
||||
void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) {
|
||||
passes_.insert({priority, std::move(pass)});
|
||||
}
|
||||
|
||||
// Free the memory allocated for all passes.
|
||||
void ClearPasses() { passes_.clear(); }
|
||||
|
||||
const Passes& passes() const { return passes_; }
|
||||
|
||||
private:
|
||||
@ -75,8 +83,11 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
|
||||
public:
|
||||
explicit MlirFunctionOptimizationPass(
|
||||
const MlirOptimizationPassRegistry* registry =
|
||||
&MlirOptimizationPassRegistry::Global())
|
||||
: registry_(registry) {}
|
||||
&MlirOptimizationPassRegistry::Global(),
|
||||
std::function<MlirBridgeRolloutPolicy(const Graph& graph,
|
||||
absl::optional<ConfigProto>)>
|
||||
mlir_rollout_policy = GetMlirBridgeRolloutPolicy)
|
||||
: registry_(registry), mlir_rollout_policy_(mlir_rollout_policy) {}
|
||||
|
||||
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
|
||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||
@ -85,6 +96,9 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
|
||||
|
||||
private:
|
||||
const MlirOptimizationPassRegistry* registry_;
|
||||
std::function<MlirBridgeRolloutPolicy(
|
||||
const tensorflow::Graph& graph, absl::optional<tensorflow::ConfigProto>)>
|
||||
mlir_rollout_policy_;
|
||||
};
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
@ -100,7 +114,8 @@ class MlirV1CompatOptimizationPass {
|
||||
public:
|
||||
virtual ~MlirV1CompatOptimizationPass() = default;
|
||||
virtual llvm::StringRef name() const = 0;
|
||||
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
|
||||
virtual bool IsEnabled(const ConfigProto& config_proto,
|
||||
const Graph& graph) const = 0;
|
||||
|
||||
virtual Status Run(const GraphOptimizationPassOptions& options,
|
||||
mlir::ModuleOp module) = 0;
|
||||
|
121
tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc
Normal file
121
tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc
Normal file
@ -0,0 +1,121 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using ::testing::_;
|
||||
using ::testing::NiceMock;
|
||||
using ::testing::Return;
|
||||
using ::testing::Test;
|
||||
|
||||
class MockMlirOptimizationPass : public MlirOptimizationPass {
|
||||
public:
|
||||
MOCK_METHOD(llvm::StringRef, name, (), (const, override));
|
||||
MOCK_METHOD(bool, IsEnabled,
|
||||
(const ConfigProto& config_proto, const Graph& graph),
|
||||
(const, override));
|
||||
MOCK_METHOD(Status, Run,
|
||||
(const ConfigProto& config_proto, mlir::ModuleOp module,
|
||||
const Graph& graph),
|
||||
(override));
|
||||
};
|
||||
|
||||
class MlirGraphOptimizationPassTest : public Test {
|
||||
public:
|
||||
void Init(MlirBridgeRolloutPolicy rollout_policy, Status pass_run_result) {
|
||||
graph_ = std::make_unique<Graph>(OpRegistry::Global());
|
||||
|
||||
function_optimization_pass_ = MlirFunctionOptimizationPass(
|
||||
&MlirOptimizationPassRegistry::Global(),
|
||||
[rollout_policy](const Graph& graph, absl::optional<ConfigProto>) {
|
||||
return rollout_policy;
|
||||
});
|
||||
|
||||
auto optimization_pass =
|
||||
std::make_unique<NiceMock<MockMlirOptimizationPass>>();
|
||||
|
||||
EXPECT_CALL(*optimization_pass, IsEnabled(_, _))
|
||||
.WillRepeatedly(Return(true));
|
||||
EXPECT_CALL(*optimization_pass, Run(_, _, _))
|
||||
.WillOnce(Return(pass_run_result));
|
||||
MlirOptimizationPassRegistry::Global().Add(0, std::move(optimization_pass));
|
||||
|
||||
flib_.reset(new FunctionLibraryDefinition(graph_->flib_def()));
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
MlirOptimizationPassRegistry::Global().ClearPasses();
|
||||
}
|
||||
|
||||
ConfigProto config_proto_;
|
||||
MlirFunctionOptimizationPass function_optimization_pass_;
|
||||
DeviceSet device_set_;
|
||||
std::unique_ptr<Graph> graph_;
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_;
|
||||
std::vector<std::string> control_ret_node_names_;
|
||||
bool control_rets_updated_{false};
|
||||
};
|
||||
|
||||
// With the bridge enabled by the user, a failing optimization pass must
// surface its error status from Run().
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) {
  const Status pass_failure(error::Code::ABORTED, "aborted");
  Init(MlirBridgeRolloutPolicy::kEnabledByUser, pass_failure);

  GraphDef graph_def_before;
  graph_->ToGraphDef(&graph_def_before);

  const Status run_status = function_optimization_pass_.Run(
      device_set_, config_proto_, &graph_, flib_.get(),
      &control_ret_node_names_, &control_rets_updated_);
  EXPECT_EQ(run_status, pass_failure);

// Proto matchers might be unavailable.
#if defined(PLATFORM_GOOGLE)
  GraphDef graph_def_after;
  graph_->ToGraphDef(&graph_def_after);
  EXPECT_THAT(graph_def_after,
              ::testing::proto::IgnoringRepeatedFieldOrdering(
                  ::testing::EquivToProto(graph_def_before)));
#endif
}
|
||||
|
||||
// With the bridge enabled after graph analysis (shadow mode), a failing
// optimization pass must not be reported: Run() is expected to return OK.
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
  const Status pass_failure(error::Code::ABORTED, "aborted");
  Init(MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis, pass_failure);

  GraphDef graph_def_before;
  graph_->ToGraphDef(&graph_def_before);

  const Status run_status = function_optimization_pass_.Run(
      device_set_, config_proto_, &graph_, flib_.get(),
      &control_ret_node_names_, &control_rets_updated_);
  EXPECT_EQ(run_status, Status::OK());

// Proto matchers might be unavailable.
#if defined(PLATFORM_GOOGLE)
  GraphDef graph_def_after;
  graph_->ToGraphDef(&graph_def_after);
  EXPECT_THAT(graph_def_after,
              ::testing::proto::IgnoringRepeatedFieldOrdering(
                  ::testing::EquivToProto(graph_def_before)));
#endif
}
|
||||
|
||||
} // namespace tensorflow
|
@ -604,6 +604,29 @@ If `condition` evaluates to false, print the list of tensors in `data`.
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_AssignOp : TF_Op<"Assign", [NoSideEffect]> {
|
||||
let summary = "Update 'ref' by assigning 'value' to it.";
|
||||
|
||||
let description = [{
|
||||
This operation outputs "ref" after the assignment is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$ref,
|
||||
TF_Tensor:$value,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$validate_shape,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$use_locking
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output_ref
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AssignAddVariableOp : TF_Op<"AssignAddVariableOp", []> {
|
||||
let summary = "Adds a value to the current value of a variable.";
|
||||
|
||||
@ -12629,6 +12652,8 @@ retained with length 1.
|
||||
OpBuilderDAG<(ins "Value":$input, "Value":$reduction_indices,
|
||||
"BoolAttr":$keep_dims)>
|
||||
];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> {
|
||||
|
@ -684,6 +684,7 @@ body: A function that takes a list of tensors and returns another
|
||||
|
||||
FlatSymbolRefAttr:$cond,
|
||||
FlatSymbolRefAttr:$body,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
@ -696,12 +697,10 @@ body: A function that takes a list of tensors and returns another
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Get the condition function.
|
||||
@ -755,8 +754,9 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
// op.
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stateless,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
|
||||
);
|
||||
let results = (outs Variadic<AnyTensor>:$output);
|
||||
|
||||
|
@ -1539,6 +1539,21 @@ void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
|
||||
build(builder, result, out_ty, input, reduction_indices, keep_dims);
|
||||
}
|
||||
|
||||
// TODO: Templatize this fold for all reduction ops.
|
||||
OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!input_ty) return {};
|
||||
auto result_ty = getType().template dyn_cast<RankedTensorType>();
|
||||
if (!result_ty) return {};
|
||||
|
||||
// Bypass this op if the result has the same shape and type. This can happen
|
||||
// if the input tensor has size 0 or size 1.
|
||||
if (!keep_dims() && input_ty == result_ty) {
|
||||
return input();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StridedSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2590,14 +2605,6 @@ static LogicalResult Verify(WhileOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhileOp canonicalization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<DropAttributes<WhileOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhileRegionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -82,25 +82,28 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) {
|
||||
mlir::SymbolTable symbol_table(
|
||||
session_initializer.getParentOfType<ModuleOp>());
|
||||
|
||||
auto init_func_op =
|
||||
symbol_table.lookup<mlir::FuncOp>(session_initializer.initializer());
|
||||
if (!init_func_op)
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function does not exist";
|
||||
for (auto sym_ref : session_initializer.initializers()) {
|
||||
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
|
||||
sym_ref.cast<FlatSymbolRefAttr>().getValue());
|
||||
|
||||
if (!init_func_op.getType().getResults().empty())
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should have no output";
|
||||
if (!init_func_op)
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function does not exist";
|
||||
|
||||
auto exported_names = GetExportedNames(init_func_op);
|
||||
if (!init_func_op.getType().getResults().empty())
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should have no output";
|
||||
|
||||
if (exported_names.empty())
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should be exported";
|
||||
auto exported_names = GetExportedNames(init_func_op);
|
||||
|
||||
if (exported_names.size() != 1)
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should have only one exported names";
|
||||
if (exported_names.empty())
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should be exported";
|
||||
|
||||
if (exported_names.size() != 1)
|
||||
return session_initializer.emitOpError()
|
||||
<< "the initializer function should have only one exported names";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
@ -291,7 +294,11 @@ static LogicalResult VerifySavedModelModule(
|
||||
|
||||
auto is_init = [&session_initializers](mlir::FuncOp func) {
|
||||
if (session_initializers.empty()) return false;
|
||||
return (*session_initializers.begin()).initializer() == func.getName();
|
||||
auto init_syms = (*session_initializers.begin()).initializers();
|
||||
return std::any_of(
|
||||
init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) {
|
||||
return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName();
|
||||
});
|
||||
};
|
||||
|
||||
SymbolTable symbol_table(module);
|
||||
@ -450,22 +457,36 @@ class OptimizeSessionInitializerPattern
|
||||
LogicalResult matchAndRewrite(SessionInitializerOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SymbolTable symbol_table(op.getParentOfType<ModuleOp>());
|
||||
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(op.initializer());
|
||||
|
||||
// The init function can only be referenced from the SessionInitializerOp.
|
||||
// And there is at most one SessionInitializerOp in the module. So if both
|
||||
// ops have no other uses or have one NoOp only, they can be simply erased.
|
||||
auto &operations = init_func_op.front().getOperations();
|
||||
if ((operations.size() == 1 && operations.front().isKnownTerminator()) ||
|
||||
(operations.size() == 2 &&
|
||||
dyn_cast<mlir::TF::NoOp>(operations.front()) &&
|
||||
operations.back().isKnownTerminator())) {
|
||||
rewriter.eraseOp(init_func_op);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
SmallVector<FuncOp, 2> to_remove;
|
||||
SmallVector<mlir::Attribute, 2> to_keep;
|
||||
for (auto sym_ref : op.initializers()) {
|
||||
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
|
||||
sym_ref.cast<FlatSymbolRefAttr>().getValue());
|
||||
|
||||
// The init function can only be referenced from the SessionInitializerOp.
|
||||
// And there is at most one SessionInitializerOp in the module. So if both
|
||||
// ops have no other uses or have one NoOp only, they can be simply
|
||||
// erased.
|
||||
auto &operations = init_func_op.front().getOperations();
|
||||
if ((operations.size() == 1 && operations.front().isKnownTerminator()) ||
|
||||
(operations.size() == 2 &&
|
||||
dyn_cast<mlir::TF::NoOp>(operations.front()) &&
|
||||
operations.back().isKnownTerminator())) {
|
||||
to_remove.push_back(init_func_op);
|
||||
} else {
|
||||
to_keep.push_back(sym_ref);
|
||||
}
|
||||
}
|
||||
|
||||
return failure();
|
||||
for (auto func_op : to_remove) rewriter.eraseOp(func_op);
|
||||
|
||||
if (to_keep.empty())
|
||||
rewriter.eraseOp(op);
|
||||
else
|
||||
op.setAttr("initializers", rewriter.getArrayAttr(to_keep));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@ -474,15 +495,22 @@ void SessionInitializerOp::getCanonicalizationPatterns(
|
||||
results.insert<OptimizeSessionInitializerPattern>(context);
|
||||
}
|
||||
|
||||
llvm::Optional<StringRef> GetSessionInitializerExportedName(ModuleOp op) {
|
||||
SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) {
|
||||
auto session_initializer_op = GetSessionInitializerOp(op);
|
||||
if (!session_initializer_op) return llvm::None;
|
||||
if (!session_initializer_op) return {};
|
||||
|
||||
SymbolTable symbol_table(op);
|
||||
auto init_func_op =
|
||||
symbol_table.lookup<mlir::FuncOp>(session_initializer_op.initializer());
|
||||
auto exported_names = GetExportedNames(init_func_op);
|
||||
return exported_names[0];
|
||||
|
||||
SmallVector<StringRef, 2> results;
|
||||
for (auto sym_ref : session_initializer_op.initializers()) {
|
||||
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
|
||||
sym_ref.cast<FlatSymbolRefAttr>().getValue());
|
||||
auto exported_names = GetExportedNames(init_func_op);
|
||||
assert(exported_names.size() == 1);
|
||||
results.push_back(exported_names[0]);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
|
@ -81,7 +81,7 @@ Type GetBoundInputArgTypeFor(mlir::Operation *op);
|
||||
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op);
|
||||
|
||||
// Returns the exported name for the session initializer function.
|
||||
llvm::Optional<StringRef> GetSessionInitializerExportedName(mlir::ModuleOp op);
|
||||
SmallVector<StringRef, 2> GetSessionInitializerExportedName(mlir::ModuleOp op);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
@ -132,13 +132,13 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
|
||||
def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
|
||||
let summary = "Initializes TensorFlow session state.";
|
||||
let description = [{
|
||||
The session initializer op marks a function that must be called by an
|
||||
external agent exactly once to initialize TensorFlow session state, and this
|
||||
must happen before any other exported functions are called. There must be no
|
||||
more than one session initializer in a saved model.
|
||||
The session initializer op marks one or more functions that must be called
|
||||
by an external agent exactly once to initialize TensorFlow session state,
|
||||
and this must happen before any other exported functions are called. There
|
||||
must be no more than one session initializer op in a saved model.
|
||||
|
||||
The `initializer` represents the initialization function. The function have
|
||||
no output and this function should be only called once.
|
||||
The `initializers` represents the initialization functions. The function
|
||||
have no output and this function should be only called once.
|
||||
|
||||
This is used, for example, to initialize hash tables stored in resources and
|
||||
accessed by resource name (rather than as resource handles or bound inputs
|
||||
@ -146,7 +146,7 @@ def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$initializer
|
||||
SymbolRefArrayAttr:$initializers
|
||||
);
|
||||
|
||||
|
||||
@ -160,7 +160,7 @@ def TfSavedModel_AssetOp: TfSavedModel_Op<"asset", [Symbol]> {
|
||||
let description = [{
|
||||
Represents an asset in the saved model that points to an external file. It
|
||||
is a scalar string tensor and it is passed as an argument to the session
|
||||
initializer function.
|
||||
initializer functions.
|
||||
|
||||
The `sym_name` represents the symbol table name used for internal IR
|
||||
references.
|
||||
|
@ -1259,24 +1259,6 @@ func @testIfDropOutputShapes(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// Check that output_shapes attribute is removed for tf.While
|
||||
func @testWhileCond(tensor<*xf32>) -> (tensor<i1>)
|
||||
func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>)
|
||||
// CHECK-LABEL: func @testWhileDropOutputShapes
|
||||
func @testWhileDropOutputShapes(tensor<*xf32>) -> (tensor<*xf32>) {
|
||||
^bb0(%arg0: tensor<*xf32>):
|
||||
// CHECK: "tf.While"
|
||||
// CHECK-NOT: output_shapes
|
||||
%1 = "tf.While"(%arg0) {
|
||||
cond = @testWhileCond,
|
||||
body = @testWhileBody,
|
||||
is_stateless = false,
|
||||
output_shapes = [#tf.shape<>]
|
||||
} : (tensor<*xf32>) -> (tensor<*xf32>)
|
||||
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testNMSV3ToNMSV4
|
||||
func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
|
||||
%max_size = constant dense<2> : tensor<i32>
|
||||
@ -1291,3 +1273,10 @@ func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<
|
||||
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> )
|
||||
return %0#0 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testSumFoldBypass
|
||||
func @testSumFoldBypass(%arg0: tensor<4x?xf16>, %arg1: tensor<*xi64>) -> tensor<4x?xf16> {
|
||||
// CHECK: return %arg0
|
||||
%0 = "tf.Sum"(%arg0, %arg1) { keep_dims = false }: (tensor<4x?xf16>, tensor<*xi64>) -> tensor<4x?xf16>
|
||||
return %0 : tensor<4x?xf16>
|
||||
}
|
||||
|
@ -24,9 +24,8 @@ func @single_cluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK: func @[[CLUSTER]]
|
||||
// CHECK: func private @[[CLUSTER]]
|
||||
// CHECK-SAME: (%[[CLUSTER_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK-SAME: sym_visibility = "private"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_ARG_0]])
|
||||
// CHECK: return %[[B_OUTPUT]]
|
||||
|
||||
@ -67,12 +66,12 @@ func @multiple_clusters(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK: func @[[CLUSTER_0]]
|
||||
// CHECK: func private @[[CLUSTER_0]]
|
||||
// CHECK-SAME: (%[[CLUSTER_0_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_0_ARG_0]])
|
||||
// CHECK: return %[[B_OUTPUT]]
|
||||
|
||||
// CHECK: func @[[CLUSTER_1]]
|
||||
// CHECK: func private @[[CLUSTER_1]]
|
||||
// CHECK-SAME: (%[[CLUSTER_1_ARG_0:[a-z0-9]*]]: tensor<?xi32>, %[[CLUSTER_1_ARG_1:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[CLUSTER_1_ARG_0]])
|
||||
// CHECK: %[[F_OUTPUT:[0-9]*]] = "tf.F"(%[[CLUSTER_1_ARG_1]], %[[E_OUTPUT]])
|
||||
@ -98,7 +97,7 @@ func @cluster_operands(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK: func @[[CLUSTER]]
|
||||
// CHECK: func private @[[CLUSTER]]
|
||||
// CHECK-SAME: () -> tensor<?xi32>
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"()
|
||||
// CHECK: return %[[A_OUTPUT]]
|
||||
|
@ -47,11 +47,11 @@ func @func2(%arg0 : tensor<i1>) -> tensor<i1> {
|
||||
|
||||
// CHECK: module
|
||||
// CHECK-SAME: @_tpu_v1_compat_outlined
|
||||
// CHECK-LABEL: func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> tensor<i1>
|
||||
// CHECK-LABEL: func nested @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf.TPUReplicateMetadata
|
||||
// CHECK-NEXT: tf.opA
|
||||
|
||||
// CHECK-LABEL: func @_tpu_v1_compat_outlined_func1(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<i1>, tensor<i32>)
|
||||
// CHECK-LABEL: func nested @_tpu_v1_compat_outlined_func1(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<i1>, tensor<i32>)
|
||||
// CHECK-NEXT: tf.TPUReplicateMetadata
|
||||
// CHECK-NEXT: tf.opA
|
||||
// CHECK-NEXT: tf.opA
|
||||
|
@ -27,14 +27,14 @@ func @foo() {
|
||||
|
||||
|
||||
// In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch.
|
||||
// CHECK: func @[[FUNCTIONALIZE_FUNC]]
|
||||
// CHECK: func private @[[FUNCTIONALIZE_FUNC]]
|
||||
// CHECK: "tf.If"
|
||||
// CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]]
|
||||
// CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]]
|
||||
|
||||
// We expect the _tf.Add in the else func and the _tf.Mul in the then func
|
||||
|
||||
// CHECK: func @[[ELSE_FUNC]]
|
||||
// CHECK: func private @[[ELSE_FUNC]]
|
||||
// CHECK: "tf.Add"
|
||||
// CHECK: func @[[THEN_FUNC]]
|
||||
// CHECK: func private @[[THEN_FUNC]]
|
||||
// CHECK: "tf.Mul"
|
||||
|
@ -40,7 +40,7 @@ library {
|
||||
}
|
||||
}
|
||||
# Drop the control dependency on arg for the node "test"
|
||||
# CHECK-LABEL: func @foo
|
||||
# CHECK-LABEL: func private @foo
|
||||
# CHECK: tf_executor.island wraps "tf.Const"()
|
||||
node_def {
|
||||
name: "test"
|
||||
|
@ -80,6 +80,6 @@ versions {
|
||||
# CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]]
|
||||
|
||||
# Verify that callee has the unit attribute tf._input_shapes.
|
||||
# CHECK: func @[[FUNCTION]]
|
||||
# CHECK: func private @[[FUNCTION]]
|
||||
# CHECK: attributes
|
||||
# CHECK-SAME: tf._input_shapes{{[,}]}}
|
||||
|
@ -90,6 +90,6 @@ library {
|
||||
# CHECK: tf.HashTableV2
|
||||
# CHECK-SAME: shared_name = "hash_table_node"
|
||||
|
||||
# CHECK: func @create_resource
|
||||
# CHECK: func private @create_resource
|
||||
# CHECK: tf.HashTableV2
|
||||
# CHECK-SAME: shared_name = "hash_table_node@create_resource"
|
||||
|
@ -49,5 +49,5 @@ library {
|
||||
}
|
||||
}
|
||||
|
||||
# CHECK-DAG: func @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
|
||||
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
|
||||
# CHECK-DAG: func private @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
|
||||
# CHECK-DAG: func private @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
|
||||
|
@ -13,7 +13,7 @@
|
||||
# CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island wraps "tf.StatefulPartitionedCall"
|
||||
# CHECK-SAME: f = @[[FUNC:[a-z0-9]*]]
|
||||
# CHECK: tf_executor.fetch %[[ISLAND_1]], %[[ISLAND_2]] : tensor<*xf32>, tensor<*xf32>
|
||||
# CHECK: func @[[FUNC]](%arg0: tensor<*xf32> {tf._user_specified_name = "inputs"}, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
# CHECK: func private @[[FUNC]](%arg0: tensor<*xf32> {tf._user_specified_name = "inputs"}, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
|
||||
node {
|
||||
name: "args_0"
|
||||
|
@ -55,4 +55,4 @@ versions {
|
||||
# site (a numerical suffix may be appended).
|
||||
|
||||
# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo0}
|
||||
# CHECK: func @foo0
|
||||
# CHECK: func private @foo0
|
||||
|
@ -74,7 +74,7 @@ library {
|
||||
}
|
||||
# The attribute "experimental_ints_on_device" and the return type INT32
|
||||
# ensure that kDeviceRetOp is used instead of kRetOp
|
||||
# CHECK-LABEL: func @foo
|
||||
# CHECK-LABEL: func private @foo
|
||||
# CHECK: tf.experimental_ints_on_device = true
|
||||
# CHECK: return %{{.*}} tensor<{{.*}}i32>
|
||||
attr {
|
||||
|
@ -5,8 +5,8 @@
|
||||
# Verify that the NameAttrList is properly turned into references to functions on import
|
||||
# CHECK: tf.Case
|
||||
# CHECK-SAME: branches = [@[[FOO:[a-z0-9]+]], @[[BAR:[a-z0-9]+]]]
|
||||
# CHECK-DAG: func @[[FOO]]()
|
||||
# CHECK-DAG: func @[[BAR]]()
|
||||
# CHECK-DAG: func private @[[FOO]]()
|
||||
# CHECK-DAG: func private @[[BAR]]()
|
||||
|
||||
node {
|
||||
name: "predicate"
|
||||
|
@ -3,7 +3,7 @@
|
||||
# Verify that the _input_shapes attribute of the FunctionDef is respected.
|
||||
# This also checks that the output type is correctly inferred based on
|
||||
# that.
|
||||
#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
|
||||
#CHECK: func private @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
|
||||
|
||||
node {
|
||||
name: "Placeholder"
|
||||
|
@ -124,5 +124,5 @@ versions {
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo110}
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo111}
|
||||
|
||||
# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
|
||||
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}
|
||||
# CHECK-LABEL: func private @foo110()
|
||||
# CHECK-LABEL: func private @foo111()
|
||||
|
@ -91,7 +91,7 @@ library {
|
||||
# CHECK-SAME: {_disable_call_shape_inference = true, device = "", f = @test_func_name0}
|
||||
# CHECK: tf_executor.fetch
|
||||
# CHECK: return
|
||||
# CHECK: func @test_func_name0
|
||||
# CHECK: func private @test_func_name0
|
||||
# CHECK-SAME: tf._resource_arg_unique_id = 0
|
||||
# CHECK-SAME: tf._resource_arg_unique_id = 0
|
||||
# CHECK: tf_executor.graph
|
||||
|
@ -4,7 +4,7 @@
|
||||
# links the function and its gradient. In MLIR a TF op's gradient function is
|
||||
# added to its list of function attributes.
|
||||
|
||||
# CHECK: func @foo0(
|
||||
# CHECK: func private @foo0(
|
||||
# CHECK: tf.gradient = @foo_grad
|
||||
|
||||
node {
|
||||
|
@ -4,8 +4,8 @@
|
||||
# functions with arg name that are the same as the graph input name
|
||||
|
||||
# CHECK: func @main(%arg0: tensor<{{.*}}i32>) -> tensor<{{.*}}i32>
|
||||
# CHECK: func @while_body
|
||||
# CHECK: func @while_cond
|
||||
# CHECK: func private @while_body
|
||||
# CHECK: func private @while_cond
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -57,7 +57,7 @@ versions {
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo0}
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
|
||||
|
||||
# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
|
||||
# CHECK-LABEL: func private @foo0()
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
|
||||
|
||||
# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}
|
||||
# CHECK-LABEL: func private @bar0()
|
||||
|
@ -106,5 +106,5 @@ versions {
|
||||
# CHECK: func @main
|
||||
# CHECK: "tf.PartitionedCall"()
|
||||
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
|
||||
# CHECK: func @[[FUNCTION]]() -> tensor<*xui8>
|
||||
# CHECK: func private @[[FUNCTION]]() -> tensor<*xui8>
|
||||
# CHECK: return {{.*}} : tensor<*xui8>
|
||||
|
@ -86,6 +86,6 @@ versions {
|
||||
# CHECK-SAME: f = @[[FUNCTION_FOO:[a-zA-Z0-9_]*]]
|
||||
|
||||
# Find callee and verify it has the stateful attribute set.
|
||||
# CHECK: func @[[FUNCTION_FOO]]
|
||||
# CHECK: func private @[[FUNCTION_FOO]]
|
||||
# CHECK-SAME: attributes
|
||||
# CHECK-SAME: tf.signature.is_stateful
|
||||
|
@ -12,7 +12,7 @@ func @f() {
|
||||
}
|
||||
|
||||
// CHECK: func @g()
|
||||
// CHECK: func @[[NEWG]]() attributes {sym_visibility = "private"}
|
||||
// CHECK: func private @[[NEWG]]()
|
||||
func @g() {
|
||||
return
|
||||
}
|
||||
@ -22,12 +22,12 @@ func @g() {
|
||||
// CHECK-LABEL: func @f
|
||||
// 2 copies of @g
|
||||
// CHECK-DAG: func @g{{.*}}
|
||||
// CHECK-DAG: func @g{{.*}}
|
||||
// CHECK-DAG: func private @g{{.*}}
|
||||
// 4 copies of @h
|
||||
// CHECK-DAG: func @h{{.*}}
|
||||
// CHECK-DAG: func @h{{.*}}
|
||||
// CHECK-DAG: func @h{{.*}}
|
||||
// CHECK-DAG: func @h{{.*}}
|
||||
// CHECK-DAG: func private @h{{.*}}
|
||||
// CHECK-DAG: func private @h{{.*}}
|
||||
// CHECK-DAG: func private @h{{.*}}
|
||||
func @f() {
|
||||
call @g() : () -> ()
|
||||
call @g() : () -> ()
|
||||
@ -47,7 +47,7 @@ func @h() {
|
||||
// -----
|
||||
// Handle error case of infinite recursion.
|
||||
// expected-error @+1 {{reached cloning limit}}
|
||||
func @f() attributes {sym_visibility = "private"} {
|
||||
func private @f() {
|
||||
call @f() : () -> ()
|
||||
call @f() : () -> ()
|
||||
return
|
||||
|
@ -33,4 +33,4 @@ func @foo(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
// CHECK: "tf.Identity"([[CALL_RESULT_REG]])
|
||||
|
||||
// Match the function name
|
||||
// CHECK: func @[[FUNCTION]]
|
||||
// CHECK: func private @[[FUNCTION]]
|
||||
|
@ -231,11 +231,11 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
|
||||
// CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>}
|
||||
// CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor<i32>}
|
||||
// CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
|
||||
// CHECK-DAG: [[ONE_I64:%.+]] = "tf.Const"() {value = dense<1> : tensor<i64>}
|
||||
// CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]])
|
||||
// CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
|
||||
// CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
|
||||
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Sum"([[FULL_PADDINGS]], [[ONE_I64]])
|
||||
// CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
|
||||
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
|
||||
// CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
|
||||
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
|
||||
// CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
|
||||
@ -256,14 +256,25 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
|
||||
}
|
||||
|
||||
// Verify the result shape for the tf.PadV2 op.
|
||||
// CHECK-LABEL: const_paddings_space_to_batch_nd
|
||||
func @const_paddings_space_to_batch_nd(%arg0: tensor<1x8x2xf32>) -> (tensor<3x5x2xf32>) {
|
||||
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
|
||||
// CHECK: "tf.PadV2"
|
||||
// CHECK-SAME: tensor<1x5x2xf32>
|
||||
|
||||
// CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 5, 2]> : tensor<3xi64>}
|
||||
// CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 5, 3, 2]> : tensor<4xi64>}
|
||||
// CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [3, 4], [0, 0{{\]\]}}> : tensor<3x2xi64>}
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
|
||||
// CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>}
|
||||
// CHECK-DAG: [[VAL5:%.+]] = "tf.PadV2"(%arg0, [[VAL2]], [[VAL3]])
|
||||
// CHECK-SAME: tensor<1x15x2xf32>
|
||||
// CHECK-DAG: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
|
||||
// CHECK-DAG: [[VAL7:%.+]] = "tf.Transpose"([[VAL6]], [[VAL4]])
|
||||
// CHECK-DAG: [[VAL8:%.+]] = "tf.Reshape"([[VAL7]], [[VAL0]])
|
||||
%2 = "tf.SpaceToBatchND"(%arg0, %0, %1) : (tensor<1x8x2xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<3x5x2xf32>
|
||||
|
||||
// CHECK: return [[VAL8]]
|
||||
return %2 : tensor<3x5x2xf32>
|
||||
}
|
||||
|
||||
@ -757,3 +768,117 @@ func @lgamma(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = "tf.Lgamma"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @imag_resize_nearest
|
||||
func @imag_resize_nearest(%arg0: tensor<1x7x7x1xi32>) -> tensor<1x3x3x1xi32> {
|
||||
%shape = "tf.Const"() {device = "", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
|
||||
// CHECK: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 3, 3, 1]>
|
||||
// CHECK: [[VAL2:%.+]] = "tf.Const"() {value = dense<[1, 49, 1]>
|
||||
// CHECK: [[VAL3:%.+]] = "tf.Const"() {value = dense<[0, 2, 4, 14, 16, 18, 28, 30, 32]> : tensor<9xi32>}
|
||||
// CHECK: [[VAL4:%.+]] = "tf.Reshape"(%arg0, [[VAL2]])
|
||||
// CHECK: [[VAL5:%.+]] = "tf.GatherV2"([[VAL4]], [[VAL3]], [[VAL0]]) {batch_dims = 0 : i64}
|
||||
// CHECK: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
|
||||
// CHECK: return [[VAL6]]
|
||||
%resize = "tf.ResizeNearestNeighbor"(%arg0, %shape) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x7x7x1xi32>, tensor<2xi32>) -> tensor<1x3x3x1xi32>
|
||||
return %resize: tensor<1x3x3x1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @imag_resize_nearest_dyn_img
|
||||
func @imag_resize_nearest_dyn_img(%arg0: tensor<1x?x?x1xi32>) -> tensor<1x3x3x1xi32> {
|
||||
%shape = "tf.Const"() {device = "", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
|
||||
// CHECK: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK: [[VAL1:%.+]] = "tf.Const"() {value = dense<[3, 1]> : tensor<2xi32>}
|
||||
// CHECK: [[VAL2:%.+]] = "tf.Const"() {value = dense<9> : tensor<1xi32>}
|
||||
// CHECK: [[VAL3:%.+]] = "tf.Const"() {value = dense<3> : tensor<1xi32>}
|
||||
// CHECK: [[VAL4:%.+]] = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>}
|
||||
// CHECK: [[VAL5:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
|
||||
// CHECK: [[VAL6:%.+]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>}
|
||||
// CHECK: [[VAL7:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
|
||||
// CHECK: [[VAL8:%.+]] = "tf.Shape"(%arg0)
|
||||
// CHECK: [[VAL9:%.+]] = "tf.Cast"([[VAL8]])
|
||||
// CHECK: [[VAL10:%.+]]:4 = "tf.Unpack"([[VAL9]]) {axis = 0 : i64}
|
||||
// CHECK: [[VAL11:%.+]] = "tf.Mul"([[VAL10]]#1, [[VAL10]]#2)
|
||||
// CHECK: [[VAL12:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL7]])
|
||||
// CHECK: [[VAL13:%.+]] = "tf.ExpandDims"([[VAL10]]#3, [[VAL7]])
|
||||
// CHECK: [[VAL14:%.+]] = "tf.ConcatV2"([[VAL12]], [[VAL3]], [[VAL3]], [[VAL13]], [[VAL7]])
|
||||
// CHECK: [[VAL15:%.+]] = "tf.Cast"([[VAL10]]#1)
|
||||
// CHECK: [[VAL16:%.+]] = "tf.Div"([[VAL15]], [[VAL6]])
|
||||
// CHECK: [[VAL17:%.+]] = "tf.Mul"([[VAL16]], [[VAL5]])
|
||||
// CHECK: [[VAL18:%.+]] = "tf.Cast"([[VAL17]])
|
||||
// CHECK: [[VAL19:%.+]] = "tf.Reshape"([[VAL18]], [[VAL1]])
|
||||
// CHECK: [[VAL20:%.+]] = "tf.Mul"([[VAL19]], [[VAL10]]#2)
|
||||
// CHECK: [[VAL21:%.+]] = "tf.Cast"([[VAL10]]#2)
|
||||
// CHECK: [[VAL22:%.+]] = "tf.Div"([[VAL21]], [[VAL6]])
|
||||
// CHECK: [[VAL23:%.+]] = "tf.Mul"([[VAL22]], [[VAL5]])
|
||||
// CHECK: [[VAL24:%.+]] = "tf.Cast"([[VAL23]])
|
||||
// CHECK: [[VAL25:%.+]] = "tf.Reshape"([[VAL24]], [[VAL4]])
|
||||
// CHECK: [[VAL26:%.+]] = "tf.AddV2"([[VAL20]], [[VAL25]])
|
||||
// CHECK: [[VAL27:%.+]] = "tf.Reshape"([[VAL26]], [[VAL2]])
|
||||
// CHECK: [[VAL28:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL7]])
|
||||
// CHECK: [[VAL29:%.+]] = "tf.ExpandDims"([[VAL11]], [[VAL7]])
|
||||
// CHECK: [[VAL30:%.+]] = "tf.ExpandDims"([[VAL10]]#3, [[VAL7]])
|
||||
// CHECK: [[VAL31:%.+]] = "tf.ConcatV2"([[VAL28]], [[VAL29]], [[VAL30]], [[VAL7]])
|
||||
// CHECK: [[VAL32:%.+]] = "tf.Reshape"(%arg0, [[VAL31]])
|
||||
// CHECK: [[VAL33:%.+]] = "tf.GatherV2"([[VAL32]], [[VAL27]], [[VAL0]]) {batch_dims = 0 : i64}
|
||||
// CHECK: [[VAL34:%.+]] = "tf.Reshape"([[VAL33]], [[VAL14]])
|
||||
// CHECK: return [[VAL34]]
|
||||
%resize = "tf.ResizeNearestNeighbor"(%arg0, %shape) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x3x3x1xi32>
|
||||
return %resize: tensor<1x3x3x1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @imag_resize_nearest_full_dyn
|
||||
func @imag_resize_nearest_full_dyn(%arg0: tensor<1x?x?x1xi32>, %arg1: tensor<2xi32>) -> tensor<1x?x?x1xi32> {
|
||||
|
||||
// CHECK: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK: [[VAL1:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
|
||||
// CHECK: [[VAL2:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK: [[VAL3:%.+]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
|
||||
// CHECK: [[VAL4:%.+]] = "tf.Const"() {value = dense<1> : tensor<1xi64>}
|
||||
// CHECK: [[VAL5:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
|
||||
// CHECK: [[VAL6:%.+]] = "tf.Shape"(%arg0)
|
||||
// CHECK: [[VAL7:%.+]] = "tf.Cast"([[VAL6]])
|
||||
// CHECK: [[VAL8:%.+]]:4 = "tf.Unpack"([[VAL7]]) {axis = 0 : i64}
|
||||
// CHECK: [[VAL9:%.+]] = "tf.Mul"([[VAL8]]#1, [[VAL8]]#2)
|
||||
// CHECK: [[VAL10:%.+]]:2 = "tf.Unpack"(%arg1) {axis = 0 : i64}
|
||||
// CHECK: [[VAL11:%.+]] = "tf.Mul"([[VAL10]]#0, [[VAL10]]#1)
|
||||
// CHECK: [[VAL12:%.+]] = "tf.ExpandDims"([[VAL8]]#0, [[VAL5]])
|
||||
// CHECK: [[VAL13:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL5]])
|
||||
// CHECK: [[VAL14:%.+]] = "tf.ExpandDims"([[VAL10]]#1, [[VAL5]])
|
||||
// CHECK: [[VAL15:%.+]] = "tf.ExpandDims"([[VAL8]]#3, [[VAL5]])
|
||||
// CHECK: [[VAL16:%.+]] = "tf.ConcatV2"([[VAL12]], [[VAL13]], [[VAL14]], [[VAL15]], [[VAL5]])
|
||||
// CHECK: [[VAL17:%.+]] = "tf.Cast"([[VAL8]]#1)
|
||||
// CHECK: [[VAL18:%.+]] = "tf.Cast"([[VAL10]]#0)
|
||||
// CHECK: [[VAL19:%.+]] = "tf.Div"([[VAL17]], [[VAL18]])
|
||||
// CHECK: [[VAL20:%.+]] = "tf.Range"([[VAL1]], [[VAL18]], [[VAL2]])
|
||||
// CHECK: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL19]])
|
||||
// CHECK: [[VAL22:%.+]] = "tf.Cast"([[VAL21]])
|
||||
// CHECK: [[VAL23:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL5]])
|
||||
// CHECK: [[VAL24:%.+]] = "tf.ConcatV2"([[VAL23]], [[VAL3]], [[VAL5]])
|
||||
// CHECK: [[VAL25:%.+]] = "tf.Reshape"([[VAL22]], [[VAL24]])
|
||||
// CHECK: [[VAL26:%.+]] = "tf.Mul"([[VAL25]], [[VAL8]]#2)
|
||||
// CHECK: [[VAL27:%.+]] = "tf.Cast"([[VAL8]]#2)
|
||||
// CHECK: [[VAL28:%.+]] = "tf.Cast"([[VAL10]]#1)
|
||||
// CHECK: [[VAL29:%.+]] = "tf.Div"([[VAL27]], [[VAL28]])
|
||||
// CHECK: [[VAL30:%.+]] = "tf.Range"([[VAL1]], [[VAL28]], [[VAL2]])
|
||||
// CHECK: [[VAL31:%.+]] = "tf.Mul"([[VAL30]], [[VAL29]])
|
||||
// CHECK: [[VAL32:%.+]] = "tf.Cast"([[VAL31]])
|
||||
// CHECK: [[VAL33:%.+]] = "tf.ExpandDims"([[VAL10]]#1, [[VAL5]])
|
||||
// CHECK: [[VAL34:%.+]] = "tf.ConcatV2"([[VAL3]], [[VAL33]], [[VAL5]])
|
||||
// CHECK: [[VAL35:%.+]] = "tf.Reshape"([[VAL32]], [[VAL34]])
|
||||
// CHECK: [[VAL36:%.+]] = "tf.AddV2"([[VAL26]], [[VAL35]])
|
||||
// CHECK: [[VAL37:%.+]] = "tf.Reshape"([[VAL11]], [[VAL4]])
|
||||
// CHECK: [[VAL38:%.+]] = "tf.Reshape"([[VAL36]], [[VAL37]])
|
||||
// CHECK: [[VAL39:%.+]] = "tf.ExpandDims"([[VAL8]]#0, [[VAL5]])
|
||||
// CHECK: [[VAL40:%.+]] = "tf.ExpandDims"([[VAL9]], [[VAL5]])
|
||||
// CHECK: [[VAL41:%.+]] = "tf.ExpandDims"([[VAL8]]#3, [[VAL5]])
|
||||
// CHECK: [[VAL42:%.+]] = "tf.ConcatV2"([[VAL39]], [[VAL40]], [[VAL41]], [[VAL5]])
|
||||
// CHECK: [[VAL43:%.+]] = "tf.Reshape"(%arg0, [[VAL42]])
|
||||
// CHECK: [[VAL44:%.+]] = "tf.GatherV2"([[VAL43]], [[VAL38]], [[VAL0]]) {batch_dims = 0 : i64}
|
||||
// CHECK: [[VAL45:%.+]] = "tf.Reshape"([[VAL44]], [[VAL16]])
|
||||
// CHECK: return [[VAL45]]
|
||||
%resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32>
|
||||
return %resize: tensor<1x?x?x1xi32>
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @unsupported_op_no_soft_placement
|
||||
func @unsupported_op_no_soft_placement() -> tensor<i32> {
|
||||
// CHECK-LABEL: func @unsupported_op_missing_soft_placement_attribute
|
||||
func @unsupported_op_missing_soft_placement_attribute() -> tensor<i32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.UnsupportedOp"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
@ -28,6 +28,24 @@ func @unsupported_op_soft_placement_false() -> tensor<i32> {
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @assert_op_string_operand
|
||||
func @assert_op_string_operand(%arg0: tensor<!tf.string>) -> tensor<i32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.Assert"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf.UnsupportedOp"
|
||||
// CHECK-SAME: _xla_outside_compilation
|
||||
// CHECK: "tf.Identity"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
%t = constant dense<true> : tensor<i1>
|
||||
"tf.Assert"(%t, %arg0) {summarize = 3} : (tensor<i1>, tensor<!tf.string>) -> ()
|
||||
%1 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
|
||||
tf_device.return %2 : tensor<i32>
|
||||
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @unsupported_op
|
||||
func @unsupported_op() -> tensor<i32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
|
||||
%0:2 = tf_executor.graph {
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
|
||||
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
|
||||
}
|
||||
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
|
||||
|
@ -299,7 +299,7 @@ func @main(%arg0: tensor<i32>) -> tensor<2xf32> {
|
||||
%2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32>
|
||||
return %2 : tensor<2xf32>
|
||||
}
|
||||
func @callee(%arg0: tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32> attributes {sym_visibility = "private"} {
|
||||
func private @callee(%arg0: tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32> {
|
||||
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file | FileCheck %s
|
||||
|
||||
// Simple IfRegion
|
||||
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: "tf.Neg"
|
||||
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.If"
|
||||
@ -24,9 +24,9 @@ func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// -----
|
||||
|
||||
// Use if condition inside the regions
|
||||
// CHECK: func @tf.IfRegion_else(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
|
||||
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
|
||||
// CHECK-NEXT: "tf.Select"(%arg0, %arg2, %arg3)
|
||||
// CHECK: func @tf.IfRegion_then(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
|
||||
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
|
||||
// CHECK-NEXT: "tf.Select"(%arg0, %arg1, %arg2)
|
||||
func @testIfCondition(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "tf.Add"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -48,9 +48,9 @@ func @testIfCondition(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
// Constant sinking for IfRegion
|
||||
|
||||
// CHECK: func @tf.IfRegion_else() -> tensor<2xf32>
|
||||
// CHECK: func private @tf.IfRegion_else() -> tensor<2xf32>
|
||||
// CHECK-NEXT: constant dense<1.0
|
||||
// CHECK: func @tf.IfRegion_then() -> tensor<2xf32>
|
||||
// CHECK: func private @tf.IfRegion_then() -> tensor<2xf32>
|
||||
// CHECK-NEXT: constant dense<0.0
|
||||
func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
|
||||
%cst_zero = constant dense<0.0> : tensor<2xf32>
|
||||
@ -67,18 +67,18 @@ func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
|
||||
// -----
|
||||
|
||||
// Nested IfRegions
|
||||
// CHECK: func @tf.IfRegion1_else
|
||||
// CHECK: func private @tf.IfRegion1_else
|
||||
// CHECK-NEXT: "tf.Acos"
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
|
||||
// CHECK: func @tf.IfRegion1_then
|
||||
// CHECK: func private @tf.IfRegion1_then
|
||||
// CHECK-NEXT: "tf.LogicalNot"
|
||||
// CHECK-NEXT: "tf.Asin"
|
||||
// CHECK-NEXT: "tf.If"({{.+}}) {else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then}
|
||||
|
||||
// CHECK: func @tf.IfRegion_else
|
||||
// CHECK: func private @tf.IfRegion_else
|
||||
// CHECK-NEXT: "tf.Neg"
|
||||
// CHECK: func @tf.IfRegion_then
|
||||
// CHECK: func private @tf.IfRegion_then
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
|
||||
func @testNested(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
@ -169,10 +169,10 @@ func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// -----
|
||||
|
||||
// No inputs, some outputs for IfRegion
|
||||
// CHECK: func @tf.IfRegion_else() -> tensor<2xf32>
|
||||
// CHECK: func private @tf.IfRegion_else() -> tensor<2xf32>
|
||||
// CHECK-NEXT: constant dense<1.000000e+00>
|
||||
// CHECK-NEXT: "tf.Neg"
|
||||
// CHECK: func @tf.IfRegion_then() -> tensor<2xf32>
|
||||
// CHECK: func private @tf.IfRegion_then() -> tensor<2xf32>
|
||||
// CHECK-NEXT: constant dense<0.000000e+00>
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
func @testSimple(%arg0: tensor<i1>) -> tensor<2xf32> {
|
||||
@ -193,9 +193,9 @@ func @testSimple(%arg0: tensor<i1>) -> tensor<2xf32> {
|
||||
|
||||
// No outputs, some inputs for IfRegion
|
||||
//
|
||||
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>)
|
||||
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<*xf32>)
|
||||
// CHECK-NEXT: "tf.Neg"
|
||||
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>)
|
||||
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<*xf32>)
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
func @printer(tensor<*xf32>) -> ()
|
||||
func @testNoOutputs(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> () {
|
||||
@ -214,9 +214,9 @@ func @testNoOutputs(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> () {
|
||||
|
||||
// -----
|
||||
// Check ToBool folding for IfRegion
|
||||
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: "tf.Neg"
|
||||
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: "tf.Abs"
|
||||
// CHECK-LABEL: @testToBoolFold
|
||||
func @testToBoolFold(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
@ -237,11 +237,11 @@ func @testToBoolFold(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
|
||||
// -----
|
||||
|
||||
// Simple WhileRegion
|
||||
// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"}
|
||||
// CHECK: func private @tf.WhileRegion_body{{.+}}
|
||||
// CHECK: "tf.Add"
|
||||
// CHECK: constant dense<1>
|
||||
// CHECK: "tf.Sub"
|
||||
// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"}
|
||||
// CHECK:func private @tf.WhileRegion_cond{{.+}}
|
||||
// CHECK: constant dense<0>
|
||||
// CHECK: "tf.NotEqual"
|
||||
// CHECK-LABEL: testValidWhileRegion
|
||||
@ -275,11 +275,11 @@ func @testValidWhileRegion(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor
|
||||
// -----
|
||||
|
||||
// WhileRegion with type mismatch
|
||||
// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"}
|
||||
// CHECK: func private @tf.WhileRegion_body{{.+}}
|
||||
// CHECK: "tf.Add"
|
||||
// CHECK: constant dense<1>
|
||||
// CHECK: "tf.Sub"
|
||||
// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"}
|
||||
// CHECK:func private @tf.WhileRegion_cond{{.+}}
|
||||
// CHECK: constant dense<0>
|
||||
// CHECK: "tf.NotEqual"
|
||||
// CHECK-LABEL: testWhileRegionTypeMismatch
|
||||
@ -309,11 +309,11 @@ func @testWhileRegionTypeMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) ->
|
||||
// -----
|
||||
|
||||
// WhileRegion with constant sinking
|
||||
// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"}
|
||||
// CHECK: func private @tf.WhileRegion_body{{.+}}
|
||||
// CHECK: constant dense<1>
|
||||
// CHECK: "tf.Add"
|
||||
// CHECK: "tf.Sub"
|
||||
// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"}
|
||||
// CHECK:func private @tf.WhileRegion_cond{{.+}}
|
||||
// CHECK: constant dense<0>
|
||||
// CHECK: "tf.NotEqual"
|
||||
// CHECK-LABEL: testWhileRegionConstantSink
|
||||
@ -342,12 +342,12 @@ func @testWhileRegionConstantSink(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) ->
|
||||
// -----
|
||||
|
||||
// WhileRegion with implicitly captured extern value in cond
|
||||
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: "tf.Add"
|
||||
// CHECK: constant dense<1>
|
||||
// CHECK: "tf.Sub"
|
||||
// CHECK: return %{{.+}}, %{{.+}}, %arg2 : tensor<*xf32>, tensor<i32>, tensor<i32>
|
||||
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: "tf.NotEqual"(%arg1, %arg2)
|
||||
// CHECK-LABEL: testWhileRegionExternInCond
|
||||
func @testWhileRegionExternInCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
|
||||
@ -376,12 +376,12 @@ func @testWhileRegionExternInCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %a
|
||||
// -----
|
||||
|
||||
// WhileRegion with implicitly captured extern value in body
|
||||
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: %0 = "tf.Add"(%arg0, %arg0)
|
||||
// CHECK: %1 = "tf.Sub"(%arg1, %arg2)
|
||||
// CHECK: return %0, %1, %arg2
|
||||
|
||||
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: constant dense<0>
|
||||
// CHECK: "tf.NotEqual"
|
||||
|
||||
@ -412,9 +412,9 @@ func @testWhileRegionExternInBody(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %a
|
||||
// -----
|
||||
|
||||
// WhileRegion with implicitly captured extern value in cond and body
|
||||
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
|
||||
// CHECK: return %{{.+}}, %{{.+}}, %arg2, %arg3
|
||||
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
|
||||
// CHECK-LABEL: testWhileRegionExternInBodyAndCond
|
||||
func @testWhileRegionExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
|
||||
%cst = constant dense<4> : tensor<i32>
|
||||
@ -443,9 +443,9 @@ func @testWhileRegionExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i
|
||||
// -----
|
||||
|
||||
// WhileRegion with same value implicitly captured in cond and body
|
||||
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: return %{{.+}}, %{{.+}}, %arg2
|
||||
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
|
||||
// CHECK-LABEL: testWhileRegionSameExternInBodyAndCond
|
||||
func @testWhileRegionSameExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
|
||||
%cst = constant dense<4> : tensor<i32>
|
||||
@ -559,9 +559,9 @@ func @testWhileRegionTrivialMultipleCasts(%arg0 : tensor<*xf32>, %arg1 : tensor<
|
||||
// -----
|
||||
|
||||
// Almost trivially transformable with extern values
|
||||
// CHECK: func @tf.WhileRegion_body
|
||||
// CHECK: func private @tf.WhileRegion_body
|
||||
// CHECK: call @while_body
|
||||
// CHECK: @tf.WhileRegion_cond
|
||||
// CHECK: func private @tf.WhileRegion_cond
|
||||
// CHECK: call @while_cond
|
||||
// CHECK-LABEL: testWhileRegionExtern
|
||||
func @while_cond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<i1>
|
||||
@ -589,9 +589,9 @@ func @testWhileRegionExtern(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tenso
|
||||
// -----
|
||||
|
||||
// Almost trivially transformable, mismatching block arguments
|
||||
// CHECK: func @tf.WhileRegion_body
|
||||
// CHECK: func private @tf.WhileRegion_body
|
||||
// CHECK: call @while_body
|
||||
// CHECK: @tf.WhileRegion_cond
|
||||
// CHECK: func private @tf.WhileRegion_cond
|
||||
// CHECK: call @while_cond
|
||||
// CHECK-LABEL: testWhileRegionBlockArgMismatch
|
||||
func @while_cond(%arg0 : tensor<i32>, %arg1 : tensor<*xf32>) -> tensor<i1>
|
||||
|
@ -17,8 +17,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-NOT: func @callee
|
||||
func @callee(%arg0: tensor<!tf.resource>) -> tensor<*xf32> attributes {sym_visibility = "private", tf.signature.is_stateful} {
|
||||
// CHECK-NOT: func private @callee
|
||||
func private @callee(%arg0: tensor<!tf.resource>) -> tensor<*xf32> attributes {tf.signature.is_stateful} {
|
||||
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<!tf.resource>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
@ -644,7 +644,7 @@ func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %ar
|
||||
%2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK: func private @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
|
||||
// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
|
||||
// CHECK-NEXT: return %[[ADD1]]
|
||||
@ -691,7 +691,7 @@ func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.res
|
||||
"tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
|
||||
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK: func private @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
|
||||
// CHECK-NEXT: return %[[ADD]]
|
||||
|
||||
@ -743,7 +743,7 @@ func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK: func private @callee_resource_lifted(%[[A0:.*]]: tensor<f32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return %[[A0]]
|
||||
|
||||
// -----
|
||||
@ -1249,3 +1249,40 @@ func @callee(%arg0: !tf_res) -> tensor<i1> {
|
||||
// CHECK-NEXT: return [[TRUE]] :
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that passthrough tf.Cast ops are removed.
|
||||
|
||||
!tf_res = type tensor<*x!tf.resource<tensor<f32>>>
|
||||
|
||||
// CHECK-LABEL: func @tpu_computation
|
||||
func @tpu_computation(%arg0: !tf_res) {
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.While"(%arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (!tf_res) -> !tf_res
|
||||
%1 = "tf.WhileRegion"(%arg0) ( {
|
||||
^cond(%carg0: !tf_res):
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
"tf.Yield"(%2) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^body(%barg0: !tf_res):
|
||||
// CHECK-NOT: tf.Cast
|
||||
%2 = "tf.Cast"(%barg0) : (!tf_res) -> !tf_res
|
||||
"tf.Yield"(%2) : (!tf_res) -> ()
|
||||
}) {is_stateless = false} : (!tf_res) -> !tf_res
|
||||
tf_device.return
|
||||
}) {} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
func @while_cond(%arg0: !tf_res) -> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_body
|
||||
func @while_body(%arg0: !tf_res) -> !tf_res {
|
||||
// CHECK-NOT: tf.Cast
|
||||
%0 = "tf.Cast"(%arg0) : (!tf_res) -> !tf_res
|
||||
return %0 : !tf_res
|
||||
}
|
||||
|
@ -439,16 +439,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
return %arg0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// Test not updating call site if a std.call is used.
|
||||
// Test iteratively updating call site if a std.call is used.
|
||||
// CHECK-LABEL: func @call_partitioned_call2(
|
||||
// CHECK-SAME: -> tensor<*xi32>
|
||||
// CHECK-SAME: -> tensor<1xi32>
|
||||
func @call_partitioned_call2() -> tensor<*xi32> {
|
||||
// CHECK: () -> tensor<*xi32>
|
||||
// CHECK: () -> tensor<1xi32>
|
||||
%0 = call @partitioned_called_func2() : () -> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
// CHECK-LABEL: func @partitioned_called_func2(
|
||||
// CHECK-SAME: -> tensor<*xi32>
|
||||
// CHECK-SAME: -> tensor<1xi32>
|
||||
func @partitioned_called_func2() -> (tensor<*xi32>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = tensor_cast %0 : tensor<1xi32> to tensor<*xi32>
|
||||
|
@ -287,14 +287,14 @@ func @main(%arg0: tensor<i1>) -> () {
|
||||
}
|
||||
|
||||
// CHECK: func @callee(%[[AARG0:.*]]: tensor<!tf.resource>, %[[AARG1:.*]]: tensor<i1>) -> tensor<!tf.resource>
|
||||
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> attributes {sym_visibility = "public"} {
|
||||
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> {
|
||||
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
|
||||
// CHECK: tf.StackPushV2"
|
||||
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
|
||||
// CHECK: func @callee_stack_decomposed(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
// CHECK: func private @callee_stack_decomposed(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
|
||||
// CHECK: "tf.AssignVariableOp"(%[[TARG0:.*]], %[[UPDATE]])
|
||||
@ -326,8 +326,8 @@ func @main(%arg0: tensor<i1>) -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @callee(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @callee(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
func private @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> {
|
||||
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
|
||||
@ -348,7 +348,7 @@ func @main() -> () {
|
||||
return
|
||||
}
|
||||
// CHECK: func @callee()
|
||||
func @callee() -> () attributes {sym_visibility = "public"} {
|
||||
func @callee() -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
|
@ -432,7 +432,7 @@ func @main() -> () {
|
||||
}
|
||||
// CHECK-LABEL: func @callee
|
||||
// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sym_visibility = "public"} {
|
||||
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
|
||||
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
|
||||
@ -442,7 +442,7 @@ func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sy
|
||||
%gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
// CHECK: func private @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
|
||||
@ -480,8 +480,8 @@ func @main() -> () {
|
||||
%read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
|
||||
return
|
||||
}
|
||||
// CHECK: func @callee(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @callee(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
func private @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
|
||||
@ -508,8 +508,8 @@ func @main() -> () {
|
||||
%call = "tf.PartitionedCall"() {f = @callee, config = "", config_proto = "", executor_type = ""} : () -> tensor<i32>
|
||||
return
|
||||
}
|
||||
// CHECK: func @callee() -> tensor<i32>
|
||||
func @callee() -> tensor<i32> attributes {sym_visibility = "public"} {
|
||||
// CHECK: func private @callee() -> tensor<i32>
|
||||
func @callee() -> tensor<i32> {
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5xf32>>>
|
||||
// CHECK: "tf.AssignVariableOp"
|
||||
@ -567,7 +567,7 @@ func @main() -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @callee
|
||||
// CHECK-LABEL: func private @callee
|
||||
// CHECK-SAME: %[[VAR:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[GVAR:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sym_visibility = "private"} {
|
||||
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
|
@ -472,14 +472,14 @@ func @main(%arg0: tensor<i1>) -> () {
|
||||
}
|
||||
|
||||
// CHECK: func @callee(%[[AARG0:.*]]: tensor<!tf.variant<tensor<f32>>>, %[[AARG1:.*]]: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>> attributes {sym_visibility = "public"} {
|
||||
func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>> {
|
||||
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
|
||||
// CHECK: "tf.TensorListPushBack"
|
||||
%push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor<!tf.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
return %push : tensor<!tf.variant<tensor<f32>>>
|
||||
}
|
||||
|
||||
// CHECK: func @callee_tensorlist_decomposed(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
|
||||
// CHECK: func private @callee_tensorlist_decomposed(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
|
||||
// CHECK-NOT: "tf.TensorListPushBack"
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
|
||||
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
@ -514,7 +514,7 @@ func @main(%arg0: tensor<i1>) -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @callee(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
|
||||
// CHECK: func private @callee(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
|
||||
func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>> attributes {sym_visibility = "private"} {
|
||||
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
|
||||
|
||||
@ -533,12 +533,12 @@ func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tens
|
||||
// Tests PartitionedCall op with no signature change on callee.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> () {
|
||||
func @main() {
|
||||
"tf.PartitionedCall"() {f = @callee, config = "", config_proto = "", executor_type = ""} : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: func @callee()
|
||||
func @callee() -> () attributes {sym_visibility = "public"} {
|
||||
// CHECK: func private @callee()
|
||||
func @callee() {
|
||||
%elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.EmptyTensorList
|
||||
|
@ -62,7 +62,7 @@ class TestModule(tf.Module):
|
||||
# CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["caller"]
|
||||
# CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL]]
|
||||
#
|
||||
# CHECK: func @[[CALLEE_INTERNAL]]
|
||||
# CHECK: func private @[[CALLEE_INTERNAL]]
|
||||
# CHECK-NOT: tf_saved_model.exported_names
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
|
||||
|
@ -35,8 +35,8 @@ from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
# CHECK-SAME: else_branch = @[[else]]
|
||||
# CHECK-SAME: then_branch = @[[then]]
|
||||
|
||||
# CHECK: func @[[else]](
|
||||
# CHECK: func @[[then]](
|
||||
# CHECK: func private @[[else]](
|
||||
# CHECK: func private @[[then]](
|
||||
|
||||
|
||||
def Test():
|
||||
|
@ -26,11 +26,11 @@ import tempfile
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
|
||||
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
|
||||
# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[init:.*]]]} : () -> ()
|
||||
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"}
|
||||
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"}
|
||||
|
||||
# CHECK: func [[init]]
|
||||
# CHECK: func @[[init]]
|
||||
# CHECK-SAME: [[ARG0:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset0]]}
|
||||
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset1]]}
|
||||
# CHECK-NEXT: [[R0:%.*]] = "tf.HashTableV2"()
|
||||
|
@ -34,9 +34,9 @@ from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
# CHECK-SAME: producer
|
||||
|
||||
# CHECK: "tf_saved_model.global_tensor"()
|
||||
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
|
||||
# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[init:.*]]]} : () -> ()
|
||||
|
||||
# CHECK: func [[init]]
|
||||
# CHECK: func @[[init]]
|
||||
# CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
|
||||
# CHECK-NEXT: [[R6:%.*]] = "tf.Const"()
|
||||
# CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"()
|
||||
@ -89,4 +89,4 @@ def Test():
|
||||
|
||||
if __name__ == '__main__':
|
||||
common_v1.set_tf_options()
|
||||
common_v1.do_test(Test)
|
||||
common_v1.do_test(Test, canonicalize=True)
|
||||
|
@ -0,0 +1,80 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# RUN: %p/import_restore_v1 | FileCheck %s
|
||||
|
||||
# pylint: disable=missing-docstring,line-too-long
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
|
||||
# Verify that the tf.versions attribute exists. It is difficult to enforce
|
||||
# contents, since the version numbers change over time. The conversion logic
|
||||
# itself is verified in the common graphdef converter, so here just assert
|
||||
# it is being invoked.
|
||||
# CHECK: module
|
||||
# CHECK-SAME: tf.versions
|
||||
# CHECK-SAME: bad_consumers
|
||||
# CHECK-SAME: min_consumer
|
||||
# CHECK-SAME: producer
|
||||
|
||||
# CHECK: tf_saved_model.session_initializer
|
||||
# CHECK-SAME: initializers = [@[[restore:.*]]]
|
||||
|
||||
# CHECK: "tf_saved_model.asset"()
|
||||
# CHECK-SAME: {filename = [[filename:.*]], sym_name = "[[sym_name:.*]]"} : () -> ()
|
||||
|
||||
# CHECK: func @[[restore]](
|
||||
# CHECK-SAME: [[variable_path:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[sym_name]]}
|
||||
# CHECK-SAME: tf_saved_model.exported_names = ["{{__tf_saved_model_session_initializer.*}}"]
|
||||
# CHECK: [[v0:%.*]] = "tf.RestoreV2"([[variable_path]]
|
||||
# CHECK: [[v1:%.*]] = "tf.Identity"([[v0]])
|
||||
# CHECK: [[handle:%.*]] = "tf.VarHandleOp"
|
||||
# CHECK-SAME: shared_name = [[shared_name:".*"]]
|
||||
# CHECK: "tf.AssignVariableOp"([[handle]], [[v1]])
|
||||
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: tf_saved_model.exported_names = ["key"]
|
||||
# CHECK: tf.VarHandleOp
|
||||
# CHECK-SAME: shared_name = [[shared_name]]
|
||||
|
||||
|
||||
def Test():
|
||||
|
||||
x = tf.constant([[1.0], [1.0], [1.0]])
|
||||
y = tf.compat.v1.get_variable(
|
||||
name='y',
|
||||
shape=(1, 3),
|
||||
initializer=tf.random_normal_initializer(),
|
||||
trainable=True)
|
||||
r = tf.matmul(x, y)
|
||||
|
||||
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
|
||||
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
|
||||
|
||||
return {
|
||||
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
|
||||
inputs={'x': tensor_info_x},
|
||||
outputs={'r': tensor_info_r},
|
||||
method_name='some_function'))
|
||||
}, None, None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
common_v1.set_tf_options()
|
||||
common_v1.do_test(Test, use_lite=True)
|
@ -4,7 +4,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: tf_saved_model.session_initializer
|
||||
"tf_saved_model.session_initializer"() {
|
||||
initializer = @init
|
||||
initializers = [@init]
|
||||
} : () -> ()
|
||||
|
||||
// CHECK: tf_saved_model.asset
|
||||
|
@ -277,7 +277,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function does not exist}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -285,7 +285,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should have no output}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
@ -298,7 +298,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
// expected-error@+1 {{there must be no more than one session_initializer op}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
@ -336,7 +336,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function does not exist}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -344,7 +344,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should have no output}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
|
||||
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
@ -356,9 +356,9 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
// expected-error@+1 {{there must be no more than one session_initializer op}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
|
||||
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
|
||||
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
@ -371,7 +371,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should be exported}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() attributes {sym_visibility = "private"} {
|
||||
return
|
||||
}
|
||||
@ -382,7 +382,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{the initializer function should have only one exported name}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() attributes { tf_saved_model.exported_names = ["a", "b"] } {
|
||||
return
|
||||
}
|
||||
|
@ -111,14 +111,14 @@ module attributes {tf_saved_model.semantics} {
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
return %c0 : tensor<f32>
|
||||
@ -145,14 +145,14 @@ module attributes {tf_saved_model.semantics} {
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
return %c0 : tensor<f32>
|
||||
@ -178,14 +178,14 @@ module attributes {tf_saved_model.semantics} {
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func private @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>)
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
@ -211,8 +211,8 @@ module attributes {tf_saved_model.semantics} {
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
|
||||
// CHECK: func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
|
||||
func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
"tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
|
||||
return %c0 : tensor<f32>
|
||||
|
@ -9,14 +9,14 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
// Test case: No matching function for the given session initializer.
|
||||
// expected-error@+1 {{'tf_saved_model.session_initializer' op the initializer function does not exist}}
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
// Test case: Invalid multiple blocks in the initializer function.
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
// expected-error@+1 {{expects exactly one block in the MLIR function}}
|
||||
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
|
||||
br ^bb1
|
||||
@ -32,7 +32,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
// CHECK: func @init()
|
||||
// CHECK: tf.Const
|
||||
// CHECK: return
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
|
||||
"tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||
return
|
||||
@ -48,7 +48,7 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
|
||||
// CHECK-NOT: tf.Const
|
||||
// CHECK-NOT: tf.AssignAddVariableOp
|
||||
// CHECK: return
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<2x8xi32>>>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "w"} : () -> tensor<*x!tf.resource<tensor<2xi32>>>
|
||||
@ -69,7 +69,7 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
|
||||
// CHECK-NOT: tf.Const
|
||||
// CHECK-NOT: tf.AssignAddVariableOp
|
||||
// CHECK: return
|
||||
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
|
||||
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<2x8xi32>>>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "w"} : () -> tensor<*x!tf.resource<tensor<2xi32>>>
|
||||
|
@ -6,12 +6,10 @@
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*x!tf.resource<tensor<64xf32>>>
|
||||
// CHECK-SAME: %[[ARG_2:.*]]: tensor<*x!tf.resource<tensor<16xf32>>>
|
||||
// CHECK-SAME: %[[ARG_3:.*]]: tensor<!tf.string>
|
||||
func @merge_same_device_variables(
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
|
||||
%arg1: tensor<*x!tf.resource<tensor<64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<16xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"},
|
||||
%arg3: tensor<!tf.string>) {
|
||||
%arg2: tensor<*x!tf.resource<tensor<16xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) {
|
||||
// CHECK-NEXT: %[[ID_0:.*]] = "tf.IdentityN"(%[[ARG_0]])
|
||||
%id0 = "tf.IdentityN"(%arg0) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
: (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
@ -19,15 +17,27 @@ func @merge_same_device_variables(
|
||||
%read0 = "tf.ReadVariableOp"(%id0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
|
||||
%read2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource<tensor<16xf32>>>) -> tensor<16xf32>
|
||||
// CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]])
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
// CHECK: tf._TPUCompileMlir
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main(%arg0: tensor<32xf32> {tf.aliasing_output = 0 : i64},
|
||||
// CHECK-SAME: %arg1: tensor<64xf32>, %arg2: tensor<16xf32>)
|
||||
%0:2 = "tf._TPUCompileMlir"() {
|
||||
metadata = "",
|
||||
mlir_module = "module attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor<32xf32>, %arg1: tensor<64xf32>, %arg2: tensor<16xf32>) -> (tensor<32xf32>, tensor<16xf32>) {\0A %0:2 = \22tf.A\22(%arg0, %arg1, %arg2) : (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>) -> (tensor<32xf32>, tensor<16xf32>)\0A return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32>\0A }\0A}"
|
||||
} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %0#0, %0#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
// CHECK: %[[EXE:.*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[COMPILE]]#1)
|
||||
// CHECK-SAME: device_var_reads_indices = [0, 1],
|
||||
// CHECK-SAME: device_var_updates_indices = [0, -1]
|
||||
%execute:2 = "tf_device.launch"() ( {
|
||||
%0:2 = "tf.TPUExecute"(%read0, %read1, %read2, %arg3) {
|
||||
%0:2 = "tf.TPUExecute"(%read0, %read1, %read2, %compile#1) {
|
||||
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>],
|
||||
Tresults = [tensor<32xf32>, tensor<16xf32>]}
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<2x!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
|
||||
tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32>
|
||||
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<16xf32>)
|
||||
// CHECK-NEXT: tf_device.return
|
||||
@ -44,26 +54,35 @@ func @merge_same_device_variables(
|
||||
// Tests that the pass does not check devices for a replicated region.
|
||||
|
||||
// CHECK-LABEL: func @merge_replicated_variables
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.string>,
|
||||
// CHECK-SAME: %[[ARG_2:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
// CHECK-SAME: %[[ARG_3:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
// CHECK-SAME: %[[ARG_2:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
func @merge_replicated_variables(
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<!tf.string>,
|
||||
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg3: tensor<*x!tf.resource<tensor<32xf32>>>) {
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>) {
|
||||
// CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
|
||||
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// CHECK-NEXT: tf_device.replicate([%[[ARG_2]], %[[ARG_3]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
tf_device.replicate([%arg2, %arg3] as %r: tensor<*x!tf.resource<tensor<32xf32>>>) {n = 2 : i32} {
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
// CHECK: tf._TPUCompileMlir
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32> {tf.aliasing_output = 0 : i64})
|
||||
%0:2 = "tf._TPUCompileMlir"() {
|
||||
metadata = "",
|
||||
mlir_module = "module attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) -> (tensor<32xf32>) {\0A %0 = \22tf.A\22(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<32xf32>)\0A return %0 : tensor<32xf32>\0A }\0A}"
|
||||
} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %0#0, %0#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
// CHECK: tf_device.replicate([%[[ARG_1]], %[[ARG_2]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
tf_device.replicate([%arg1, %arg2] as %r: tensor<*x!tf.resource<tensor<32xf32>>>) {n = 2 : i32} {
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[ARG_1]])
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[COMPILE]]#1)
|
||||
// CHECK-SAME: device_var_reads_indices = [1],
|
||||
// CHECK-SAME: device_var_updates_indices = [0]
|
||||
%read1 = "tf.ReadVariableOp"(%r) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%0 = "tf.TPUExecute"(%read0, %read1, %arg1)
|
||||
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
|
||||
%0 = "tf.TPUExecute"(%read0, %read1, %compile#1)
|
||||
: (tensor<32xf32>, tensor<32xf32>, tensor<2x!tf.string>) -> tensor<32xf32>
|
||||
tf_device.return %0 : tensor<32xf32>
|
||||
}) {device = ""} : () -> tensor<32xf32>
|
||||
// CHECK-NEXT: tf_device.return
|
||||
@ -86,7 +105,6 @@ func @merge_replicated_variables(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*x!tf.resource<tensor<64xf32>>>
|
||||
// CHECK-SAME: %[[ARG_2:.*]]: tensor<32xf32>
|
||||
// CHECK-SAME: %[[ARG_3:.*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_4:.*]]: tensor<*x!tf.resource<tensor<8xf32>>>
|
||||
// CHECK-SAME: %[[ARG_5:.*]]: tensor<*x!tf.resource<tensor<2xf32>>>
|
||||
// CHECK-SAME: %[[ARG_6:.*]]: tensor<2xf32>
|
||||
@ -94,7 +112,6 @@ func @interferencing_accesses(
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
|
||||
%arg1: tensor<*x!tf.resource<tensor<64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
|
||||
%arg2: tensor<32xf32>,
|
||||
%arg3: tensor<!tf.string>,
|
||||
%arg4: tensor<*x!tf.resource<tensor<8xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
|
||||
%arg5: tensor<*x!tf.resource<tensor<2xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
|
||||
%arg6: tensor<2xf32>) -> (tensor<8xf32>) {
|
||||
@ -108,15 +125,26 @@ func @interferencing_accesses(
|
||||
"tf.AssignVariableOp"(%arg5, %arg6) : (tensor<*x!tf.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
|
||||
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
|
||||
%read2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]])
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
// CHECK: tf._TPUCompileMlir
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32> {tf.aliasing_output = 1 : i64})
|
||||
%0:2 = "tf._TPUCompileMlir"() {
|
||||
metadata = "",
|
||||
mlir_module = "module attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) -> (tensor<32xf32>) {\0A %0 = \22tf.A\22(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<32xf32>)\0A return %0 : tensor<32xf32>\0A }\0A}"
|
||||
} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
tf_device.return %0#0, %0#1 : tensor<!tf.string>, tensor<2x!tf.string>
|
||||
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
|
||||
// CHECK: %[[EXE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[COMPILE]]#1)
|
||||
// CHECK-SAME: device_var_reads_indices = [1, 2],
|
||||
// CHECK-SAME: device_var_updates_indices = [1, -1]
|
||||
%execute:3 = "tf_device.launch"() ( {
|
||||
%0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) {
|
||||
%0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %compile#1) {
|
||||
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>],
|
||||
Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>]}
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<!tf.string>)
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<2x!tf.string>)
|
||||
-> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
|
||||
tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32>
|
||||
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
|
||||
|
@ -859,7 +859,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-SAME: func @nested_func
|
||||
// CHECK-SAME: func private @nested_func
|
||||
// CHECK-SAME: tf.D
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
@ -908,7 +908,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-SAME: func @referenced_func
|
||||
// CHECK-SAME: func private @referenced_func
|
||||
// CHECK-SAME: tf.D
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: "tf_device.launch"
|
||||
@ -1007,7 +1007,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-COUNT-2: call @referenced_func
|
||||
// CHECK-COUNT-1: func @referenced_func
|
||||
// CHECK-COUNT-1: func private @referenced_func
|
||||
// CHECK-SAME: tf.D
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: "tf_device.launch"
|
||||
@ -1161,13 +1161,13 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-SAME: func @referenced_func3
|
||||
// CHECK-SAME: func private @referenced_func3
|
||||
// CHECK-SAME: tf.I
|
||||
// CHECK-SAME: func @referenced_func2
|
||||
// CHECK-SAME: func private @referenced_func2
|
||||
// CHECK-SAME: tf.H
|
||||
// CHECK-SAME: func @referenced_func1
|
||||
// CHECK-SAME: func private @referenced_func1
|
||||
// CHECK-SAME: tf.G
|
||||
// CHECK-SAME: func @referenced_func0
|
||||
// CHECK-SAME: func private @referenced_func0
|
||||
// CHECK-SAME: tf.F
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
|
@ -44,9 +44,9 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
%10 = "tf.Identity"(%9) {device = ""} : (tensor<i1>) -> tensor<i1>
|
||||
return %10 : tensor<i1>
|
||||
}
|
||||
// CHECK-LABEL: func @_func
|
||||
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
|
||||
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
|
||||
// CHECK-LABEL: func private @_func
|
||||
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
|
||||
func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
|
||||
%0 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
|
||||
%1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
|
||||
%2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
@ -112,9 +112,9 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSI
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @_func
|
||||
// CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
|
||||
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
|
||||
// CHECK-LABEL: func private @_func
|
||||
// CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
|
||||
func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
|
||||
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
|
@ -112,7 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
|
||||
LogicalResult ConvertWhileOp(WhileOp while_op) {
|
||||
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
|
||||
while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
|
||||
while_op.is_stateless(), while_op.parallel_iterations());
|
||||
while_op.output_shapes(), while_op.parallel_iterations(),
|
||||
while_op.is_stateless());
|
||||
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
|
||||
|
||||
YieldOp cond_yield =
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user