commit
04b7141f30
26
.bazelrc
26
.bazelrc
@ -396,17 +396,21 @@ build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_P
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
|
||||
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true
|
||||
build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||
|
||||
|
@ -747,9 +747,7 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
|
||||
}
|
||||
|
||||
void TFE_ContextClearCaches(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->ClearCachesAndThreadExecutors();
|
||||
tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
|
||||
}
|
||||
|
||||
// Set server_def on the context, possibly updating it.
|
||||
|
@ -79,6 +79,8 @@ class AbstractContextInterface {
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -1017,10 +1017,10 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
inst->getName().print(os);
|
||||
// Print out attributes except for large elementsattributes (which should
|
||||
// rarely be the cause why the legalization didn't happen).
|
||||
if (!inst->getAttrList().getAttrs().empty()) {
|
||||
if (!inst->getMutableAttrDict().getAttrs().empty()) {
|
||||
os << " {";
|
||||
bool first = true;
|
||||
for (auto& named_attr : inst->getAttrList().getDictionary()) {
|
||||
for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) {
|
||||
os << (!first ? ", " : "");
|
||||
first = false;
|
||||
named_attr.first.print(os);
|
||||
|
@ -48,10 +48,14 @@ class TfDevice_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<TfDevice_Dialect, mnemonic, traits> { }
|
||||
|
||||
def TfDevice_LaunchOp : TfDevice_Op<"launch",
|
||||
[SingleBlockImplicitTerminator<"ReturnOp">]>
|
||||
{
|
||||
let summary = [{The `tf_device.launch` op captures all needed live-in values
|
||||
and launches containing operations on target device.}];
|
||||
[SingleBlockImplicitTerminator<"ReturnOp">]> {
|
||||
let summary = [{
|
||||
The `tf_device.launch` op launches containing operations on target device.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
This op captures all needed live-in values.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$device
|
||||
@ -85,8 +89,8 @@ def TfDevice_LaunchOp : TfDevice_Op<"launch",
|
||||
|
||||
def TfDevice_ReturnOp : TfDevice_Op<"return", [Terminator]> {
|
||||
let summary = [{
|
||||
The `tf_device.return` operation terminates and returns values from
|
||||
`tf_device.launch` operation;
|
||||
The `tf_device.return` operation terminates and returns values from a
|
||||
`tf_device` dialect operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@ -121,7 +125,6 @@ def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> {
|
||||
let extraClassDeclaration = [{
|
||||
StringRef getFunc() { return func(); }
|
||||
StringRef getDevice() { return device(); }
|
||||
FunctionType getFuncType();
|
||||
}];
|
||||
}
|
||||
|
||||
@ -281,4 +284,51 @@ For example:
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TfDevice_ClusterOp : TfDevice_Op<"cluster",
|
||||
[SingleBlockImplicitTerminator<"ReturnOp">]> {
|
||||
let summary = [{
|
||||
The `tf_device.cluster` op wraps containing operations in a region.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
This op can be used to group operations, and captures all needed live-in values.
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs
|
||||
Variadic<AnyType>:$results
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Block &GetBody() { return getOperation()->getRegion(0).front(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def TfDevice_ClusterFuncOp : TfDevice_Op<"cluster_func", []> {
|
||||
let summary = [{
|
||||
The `tf_device.cluster_func` launches a function containing the body of a
|
||||
cluster.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
This op is used for outlining a cluster.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$func,
|
||||
Variadic<AnyType>:$operands
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<AnyType>:$results
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
StringRef getFunc() { return func(); }
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TF_DEVICE_DIALECT
|
||||
|
@ -91,3 +91,55 @@ func @nodep_multiple_outside_compilation() -> () {
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with single TPU cluster return.
|
||||
|
||||
// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
|
||||
func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]]
|
||||
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.launch"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
%3 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %3 : tensor<?xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.
|
||||
|
||||
// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
|
||||
func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[TPU_LAUNCH_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: tf_device.return %[[TPU_LAUNCH_OUTPUT]]
|
||||
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
|
||||
%1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2, %3 = "tf_device.launch"() ( {
|
||||
%4 = "tf.A"() : () -> tensor<?xf32>
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
%5 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
|
||||
tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// TODO(b/154363171): Add test cases for when output of outside compilation is returned by parallel_execute.
|
||||
|
@ -65,7 +65,8 @@ constexpr char kBadTPUReplicateAttrMsg[] =
|
||||
"requires '_tpu_replicate' string attribute";
|
||||
|
||||
// Mapping for `_tpu_replicate` attribute to TPUReplicateMetadata attributes.
|
||||
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttributeList, 8>;
|
||||
using MetadataMap =
|
||||
llvm::SmallDenseMap<llvm::StringRef, MutableDictionaryAttr, 8>;
|
||||
|
||||
// Mapping for `_tpu_replicate` attribute to ops of a cluster.
|
||||
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef,
|
||||
@ -83,7 +84,7 @@ struct TPUClusterFormation
|
||||
LogicalResult CollectMetadata(Operation* op, MetadataMap* metadata_map) {
|
||||
auto result =
|
||||
op->walk([&](TF::TPUReplicateMetadataOp metadata_op) -> WalkResult {
|
||||
NamedAttributeList attrs = metadata_op.getAttrs();
|
||||
MutableDictionaryAttr attrs = metadata_op.getAttrs();
|
||||
|
||||
// Missing or bad `_tpu_replicate` attribute.
|
||||
auto tpu_replicate_attr = attrs.get(kTPUReplicateAttr);
|
||||
|
@ -103,6 +103,18 @@ tf_device::LaunchOp CreateLaunchOpForCluster(OpBuilder* builder,
|
||||
return launch_op;
|
||||
}
|
||||
|
||||
// Propagates the return from `parallel_execute_op` to parent replicate
|
||||
// op if it exists.
|
||||
void PropagateParallelExecuteReturnToReplicate(
|
||||
tf_device::ParallelExecuteOp parallel_execute_op) {
|
||||
// Update the return for the parallel_execute op parent.
|
||||
auto replicate = llvm::dyn_cast_or_null<tf_device::ReplicateOp>(
|
||||
parallel_execute_op.getParentOp());
|
||||
if (replicate)
|
||||
replicate.GetBody().getTerminator()->setOperands(
|
||||
parallel_execute_op.execute_outputs());
|
||||
}
|
||||
|
||||
// Creates a `parallel_execute` op in place of launch with 'clusters` and
|
||||
// 'launch` as regions.
|
||||
void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
|
||||
@ -111,14 +123,8 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
|
||||
// Create parallel_execute regions. The original TPU cluster computation
|
||||
// is the extra region.
|
||||
int num_regions = 1 + clusters.size();
|
||||
// TODO(b/154363171): Correctly determine output_types. Add tests to confirm
|
||||
// that the types for parallel_execute_op match the concatenated output
|
||||
// types of the contained regions.
|
||||
// TODO(b/154363171): Remap the results of the `launch` op to use the
|
||||
// results of the `parallel_execute` op.
|
||||
llvm::SmallVector<Type, 8> concatenated_output_types;
|
||||
auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
|
||||
launch.getLoc(), num_regions, concatenated_output_types);
|
||||
launch.getLoc(), num_regions, launch.results().getTypes());
|
||||
|
||||
// Move outside compilation clusters to parallel_execute regions.
|
||||
for (const auto& cluster : llvm::enumerate(clusters)) {
|
||||
@ -131,7 +137,10 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
|
||||
CreateLaunchOpForCluster(&builder, cluster_ops.back());
|
||||
MoveClusterOpsToLaunchOp(launch_op, cluster_ops);
|
||||
builder.setInsertionPointToEnd(&outside_block);
|
||||
builder.create<tf_device::ReturnOp>(launch.getLoc(), launch.getResults());
|
||||
// TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute
|
||||
// regions either through communication with TPU parallel_execute regions
|
||||
// or modifying parallel_execute returns.
|
||||
builder.create<tf_device::ReturnOp>(launch.getLoc(), ArrayRef<Value>{});
|
||||
}
|
||||
|
||||
// Move the launch body to last parallel_execute block.
|
||||
@ -140,6 +149,11 @@ void CreateParallelExecuteFromClusters(tf_device::LaunchOp launch,
|
||||
builder.setInsertionPointToEnd(&inside_block);
|
||||
builder.create<tf_device::ReturnOp>(launch.getLoc(), launch.getResults());
|
||||
launch.getOperation()->moveBefore(inside_block.getTerminator());
|
||||
|
||||
PropagateParallelExecuteReturnToReplicate(parallel_execute_op);
|
||||
// TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute
|
||||
// regions either through communication with TPU parallel_execute regions
|
||||
// or modifying parallel_execute returns.
|
||||
}
|
||||
|
||||
void TPUExtractOutsideCompilation::runOnFunction() {
|
||||
|
@ -113,7 +113,7 @@ void BreakUpIslands::runOnFunction() {
|
||||
state.addOperands(operands);
|
||||
Operation* new_op = builder.createOperation(state);
|
||||
item.replaceAllUsesWith(new_op);
|
||||
new_op->setAttrs(item.getAttrList());
|
||||
new_op->setAttrs(item.getMutableAttrDict());
|
||||
item.erase();
|
||||
}
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
||||
op.getResult(0).replaceAllUsesWith(replacement->getResult(0));
|
||||
for (int i : llvm::seq<int>(1, op.getNumResults()))
|
||||
op.getResult(i).replaceAllUsesWith(replacement->getResult(i + 1));
|
||||
replacement->setAttrs(op.getAttrList());
|
||||
replacement->setAttrs(op.getMutableAttrDict());
|
||||
op.erase();
|
||||
continue;
|
||||
} else if (op.getName().getStringRef() == "_tf.NextIteration.sink") {
|
||||
@ -177,7 +177,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
||||
frame_name_to_loop[frame.getValue()];
|
||||
replacement = builder.create<tf_executor::NextIterationSinkOp>(
|
||||
loc, srcOp.token(), operands, ArrayRef<NamedAttribute>{});
|
||||
replacement->setAttrs(op.getAttrList());
|
||||
replacement->setAttrs(op.getMutableAttrDict());
|
||||
op.erase();
|
||||
continue;
|
||||
} else if (op.getName().getStringRef() == "_tf.LoopCond") {
|
||||
@ -220,7 +220,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
||||
// Create the operation inside the island
|
||||
OpBuilder island_builder = OpBuilder::atBlockEnd(&island.GetBody());
|
||||
Operation *inner_op = island_builder.createOperation(result);
|
||||
inner_op->setAttrs(op.getAttrList());
|
||||
inner_op->setAttrs(op.getMutableAttrDict());
|
||||
|
||||
// Add the terminator for the island
|
||||
SmallVector<Value, 8> ret_vals(inner_op->getResults());
|
||||
@ -230,7 +230,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
||||
// Copy the attributes from the original operation to the replacement and
|
||||
// remap the results.
|
||||
if (!isa<tf_executor::IslandOp>(replacement))
|
||||
replacement->setAttrs(op.getAttrList());
|
||||
replacement->setAttrs(op.getMutableAttrDict());
|
||||
for (int i : llvm::seq<int>(0, op.getNumResults()))
|
||||
op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
|
||||
op.erase();
|
||||
|
@ -136,7 +136,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(state);
|
||||
replacement->setAttrs(wrapped_op.getAttrList());
|
||||
replacement->setAttrs(wrapped_op.getMutableAttrDict());
|
||||
|
||||
for (auto ops_and_ret_vals :
|
||||
llvm::zip(wrapped_op.getResults(), replacement->getResults()))
|
||||
@ -208,7 +208,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
||||
|
||||
// Create the replacement operation.
|
||||
auto *replacement = builder.createOperation(state);
|
||||
replacement->setAttrs(op.getAttrList());
|
||||
replacement->setAttrs(op.getMutableAttrDict());
|
||||
|
||||
if (auto next_iteration =
|
||||
dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
|
||||
|
@ -258,7 +258,8 @@ Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn) {
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
// Mark main function as public, and other functions as private.
|
||||
tf2xla.addPass(
|
||||
@ -277,7 +278,11 @@ Status ConvertMLIRToXlaComputation(
|
||||
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
|
||||
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
|
||||
for (auto& target_pass : custom_legalization_passes) {
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
|
||||
}
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
|
||||
// Leverage tf2xla kernels for ops that didn't get lowered in the previous
|
||||
// legalization pass.
|
||||
@ -324,7 +329,8 @@ static Status CompileMlirToXlaHlo(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result) {
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
if (VLOG_IS_ON(1))
|
||||
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
||||
|
||||
@ -342,7 +348,8 @@ static Status CompileMlirToXlaHlo(
|
||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||
module_op, device_type, compilation_result->computation.get(),
|
||||
use_tuple_args,
|
||||
/*return_tuple=*/true, shape_representation_fn));
|
||||
/*return_tuple=*/true, shape_representation_fn,
|
||||
std::move(custom_legalization_passes)));
|
||||
|
||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||
// node.
|
||||
@ -372,7 +379,8 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result) {
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext mlir_context;
|
||||
mlir::OwningModuleRef mlir_module;
|
||||
@ -381,7 +389,8 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
|
||||
return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type,
|
||||
use_tuple_args, shape_representation_fn,
|
||||
compilation_result);
|
||||
compilation_result,
|
||||
std::move(custom_legalization_passes));
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
@ -389,7 +398,8 @@ Status CompileGraphToXlaHlo(
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result) {
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext context;
|
||||
GraphImportConfig config;
|
||||
@ -400,7 +410,8 @@ Status CompileGraphToXlaHlo(
|
||||
|
||||
return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes,
|
||||
device_type, use_tuple_args,
|
||||
shape_representation_fn, compilation_result);
|
||||
shape_representation_fn, compilation_result,
|
||||
std::move(custom_legalization_passes));
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
@ -50,11 +51,14 @@ namespace tensorflow {
|
||||
// shape_representation_fn: when this is set, this shape representation function
|
||||
// will be used to determine argument and result shapes. Otherwise the
|
||||
// original shape will be used as is.
|
||||
// custom_legalization_passes: passes to run before the default TF legalization
|
||||
// passes for backend-specific ops.
|
||||
Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr);
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
|
||||
|
||||
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
||||
// metadata and stores them in CompilationResult.
|
||||
@ -62,7 +66,8 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result);
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
|
||||
|
||||
// Same as the above but takes input as TensorFlow Graph.
|
||||
Status CompileGraphToXlaHlo(
|
||||
@ -70,7 +75,8 @@ Status CompileGraphToXlaHlo(
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result);
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -252,6 +252,37 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
|
||||
::testing::HasSubstr(expected_signature));
|
||||
}
|
||||
|
||||
TEST(CompileSerializedMlirToXlaHloTest, ShapeInferenceAfterLegalization) {
|
||||
constexpr char mlir_module[] = R"(
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) {
|
||||
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>)
|
||||
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
std::vector<TensorShape> arg_shapes{TensorShape({8, 16, 16, 64}),
|
||||
TensorShape({64})};
|
||||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
const xla::HloModuleConfig module_config(
|
||||
compilation_result.computation->GetProgramShape().ValueOrDie());
|
||||
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||
compilation_result.computation->proto(), module_config);
|
||||
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||
|
||||
constexpr char expected_signature[] =
|
||||
R"(-> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0]))";
|
||||
EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(),
|
||||
::testing::HasSubstr(expected_signature));
|
||||
}
|
||||
|
||||
TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
|
||||
constexpr char mlir_module[] = R"(
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
|
@ -14,6 +14,7 @@ package_group(
|
||||
"//learning/brain/experimental/dtensor/...",
|
||||
"//learning/brain/experimental/mlir/...",
|
||||
"//learning/brain/google/xla/kernels/...",
|
||||
"//learning/brain/google/xla/mlir/...",
|
||||
"//learning/brain/swift/swift_mlir/...",
|
||||
"//learning/pathways/data_parallel/tf2xla/...",
|
||||
"//platforms/xla/...",
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
@ -810,9 +811,53 @@ OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
||||
// ConcatenateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class ConcatenateOperandRemoval : public OpRewritePattern<ConcatenateOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ConcatenateOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto axis = op.dimension().getLimitedValue();
|
||||
llvm::SmallVector<Value, 6> new_operands;
|
||||
for (auto operand : op.getOperands()) {
|
||||
auto ty = operand.getType().cast<ShapedType>();
|
||||
if (ty.getDimSize(axis) != 0) {
|
||||
new_operands.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
||||
if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) {
|
||||
rewriter.replaceOpWithNewOp<ConcatenateOp>(op, op.getResult().getType(),
|
||||
new_operands, op.dimension());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConcatenateOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<ConcatenateOperandRemoval>(context);
|
||||
}
|
||||
|
||||
OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getNumOperands() == 1) return getOperand(0);
|
||||
return {};
|
||||
|
||||
ShapedType type = getResult().getType().cast<ShapedType>();
|
||||
if (!type.hasStaticShape()) return {};
|
||||
|
||||
auto axis = dimension().getLimitedValue();
|
||||
llvm::SmallVector<Value, 6> new_operands;
|
||||
for (auto operand : getOperands()) {
|
||||
auto ty = operand.getType().cast<ShapedType>();
|
||||
if (ty.getDimSize(axis) != 0) {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(type, ArrayRef<Attribute>());
|
||||
}
|
||||
|
||||
static LogicalResult Verify(ConcatenateOp op) {
|
||||
@ -1381,6 +1426,89 @@ void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand,
|
||||
operand, start_indices, limit_indices, strides);
|
||||
}
|
||||
|
||||
template <typename I, typename E>
|
||||
static void SliceElements(I values, ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> starts, ArrayRef<int64_t> limits,
|
||||
ArrayRef<int64_t> strides,
|
||||
llvm::SmallVectorImpl<E>* out_values) {
|
||||
assert(starts.size() == limits.size());
|
||||
assert(starts.size() == strides.size());
|
||||
if (starts.empty()) return;
|
||||
|
||||
int64_t start = starts.front();
|
||||
int64_t limit = limits.front();
|
||||
int64_t stride = strides.front();
|
||||
if (starts.size() == 1) {
|
||||
for (int i = start; i < limit; i += stride) {
|
||||
out_values->push_back(*(values + i));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
for (; start < limit; start += stride) {
|
||||
auto begin = values + start * sizes.front();
|
||||
SliceElements<I, E>(begin, sizes.drop_front(), starts.drop_front(),
|
||||
limits.drop_front(), strides.drop_front(), out_values);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename I, typename E>
|
||||
static Attribute FoldSlice(SliceOp* op, I values) {
|
||||
auto start = llvm::to_vector<6>(op->start_indices().getValues<int64_t>());
|
||||
auto limit = llvm::to_vector<6>(op->limit_indices().getValues<int64_t>());
|
||||
auto stride = llvm::to_vector<6>(op->strides().getValues<int64_t>());
|
||||
|
||||
auto result_type = op->operand().getType().cast<ShapedType>();
|
||||
if (!result_type.hasStaticShape()) return {};
|
||||
|
||||
auto shape = result_type.getShape();
|
||||
int64_t count = result_type.getNumElements();
|
||||
// Compute the striding for each dimension.
|
||||
llvm::SmallVector<int64_t, 6> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (auto v : shape) {
|
||||
count = count / v;
|
||||
sizes.push_back(count);
|
||||
}
|
||||
|
||||
llvm::SmallVector<E, 6> out_values;
|
||||
out_values.reserve(result_type.getNumElements());
|
||||
SliceElements<I, E>(values, sizes, start, limit, stride, &out_values);
|
||||
|
||||
return DenseElementsAttr::get(op->getResult().getType().cast<ShapedType>(),
|
||||
out_values);
|
||||
}
|
||||
|
||||
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Check if the SliceOp is a NoOp operation.
|
||||
auto operand_shape = getOperand().getType().cast<ShapedType>().getShape();
|
||||
auto result_type = getResult().getType().cast<ShapedType>();
|
||||
auto result_shape = result_type.getShape();
|
||||
|
||||
if (result_type.hasStaticShape() && (operand_shape == result_shape)) {
|
||||
return getOperand();
|
||||
}
|
||||
|
||||
if (operands.empty() || !operands.front()) return {};
|
||||
|
||||
// Evaluate for statically valued inputs.
|
||||
DenseElementsAttr elements = operands.front().dyn_cast<DenseElementsAttr>();
|
||||
if (!elements) return {};
|
||||
|
||||
auto etype = elements.getType().getElementType();
|
||||
if (etype.isa<IntegerType>()) {
|
||||
return FoldSlice<DenseElementsAttr::IntElementIterator, APInt>(
|
||||
this, elements.getIntValues().begin());
|
||||
} else if (etype.isa<FloatType>()) {
|
||||
return FoldSlice<
|
||||
llvm::mapped_iterator<DenseElementsAttr::IntElementIterator,
|
||||
std::function<APFloat(const APInt&)>>,
|
||||
APFloat>(this, elements.getFloatValues().begin());
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
// Returns output dimension size for slice result for the given arguments.
|
||||
// Returns -1 if arguments are illegal.
|
||||
static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end,
|
||||
|
@ -647,6 +647,8 @@ def HLO_SliceOp: HLO_Op<
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value operand, "
|
||||
"DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
|
||||
@ -845,6 +847,7 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate",
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
|
||||
}
|
||||
|
@ -64,14 +64,14 @@ def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
|
||||
|
||||
// Dynamic representation of a shape vector as a tensor.
|
||||
def HLO_DimensionTensor : ShapedContainerType<
|
||||
[Index, AnySignlessInteger],
|
||||
[Index, HLO_Pred, HLO_Int],
|
||||
And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
|
||||
"a 1D tensor of dimensions">;
|
||||
|
||||
// In general, static shaped tensor constraints should be avoided unless
|
||||
// it is for a legacy op which is only correct with static shapes.
|
||||
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
|
||||
AnyFloat, AnySignlessInteger, HLO_Complex]>;
|
||||
AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XLA on tensors combined type definitions.
|
||||
|
@ -1,5 +1,50 @@
|
||||
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: concatenate_noop
|
||||
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
|
||||
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
|
||||
%0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK: return [[ARG]]
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_remove_operand
|
||||
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
|
||||
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
|
||||
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK: return [[ARG0]]
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_empty_bool
|
||||
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
|
||||
// CHECK: xla_hlo.constant
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
|
||||
|
||||
return %0 : tensor<0xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_empty_int
|
||||
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
|
||||
// CHECK: xla_hlo.constant
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
|
||||
|
||||
return %0 : tensor<0xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_empty_float
|
||||
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
|
||||
// CHECK: xla_hlo.constant
|
||||
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
|
||||
|
||||
return %0 : tensor<0xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: dynamic_slice_variable_start
|
||||
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||
// CHECK: "xla_hlo.dynamic-slice"
|
||||
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
|
||||
@ -31,6 +76,70 @@ func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1:
|
||||
return %2 : tensor<?x4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_noop
|
||||
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
|
||||
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
|
||||
%0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
|
||||
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
return %0 : tensor<2x2xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_1D_fold
|
||||
func @slice_1D_fold() -> tensor<2xi64> {
|
||||
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<[7, 9]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
|
||||
return %1 : tensor<2xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_1D_fp
|
||||
func @slice_1D_fp() -> tensor<2xf32> {
|
||||
%0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
|
||||
// CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_1D_strided_fold
|
||||
func @slice_1D_strided_fold() -> tensor<2xi64> {
|
||||
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
|
||||
// CHECK: xla_hlo.constant dense<[7, 10]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
|
||||
return %1 : tensor<2xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_fold
|
||||
func @slice_2D_fold() -> tensor<2x2xi64> {
|
||||
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: xla_hlo.constant dense<[
|
||||
// CHECK-SAME: [6, 7],
|
||||
// CHECK-SAME: [10, 11]
|
||||
// CHECK-SAME: ]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
|
||||
return %1 : tensor<2x2xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_fold_horizontal
|
||||
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
|
||||
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: xla_hlo.constant dense<[
|
||||
// CHECK-SAME: [0, 1, 2, 3]
|
||||
// CHECK-SAME: ]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
|
||||
return %1 : tensor<1x4xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: slice_2D_fold_vertical
|
||||
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
|
||||
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
|
||||
// CHECK-NEXT: xla_hlo.constant dense<[
|
||||
// CHECK-SAME: [2], [6], [10], [14]
|
||||
// CHECK-SAME: ]>
|
||||
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
|
||||
return %1 : tensor<4x1xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @broadcast_in_dim_identity
|
||||
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
|
||||
// CHECK: return %arg0
|
||||
|
@ -1749,10 +1749,12 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
|
||||
op.getLoc(),
|
||||
rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5));
|
||||
|
||||
auto shaped_type = operand.getType().cast<ShapedType>();
|
||||
auto type = operand.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return rewriter.notifyMatchFailure(op, "requires ranked tensor type");
|
||||
auto constant_ones = rewriter.create<BroadcastOp>(
|
||||
op.getLoc(), shaped_type, scalar_one,
|
||||
GetI64ElementsAttr(shaped_type.getShape(), &rewriter));
|
||||
op.getLoc(), type, scalar_one,
|
||||
GetI64ElementsAttr(type.getShape(), &rewriter));
|
||||
|
||||
auto scaled_input = rewriter.create<MulOp>(
|
||||
op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
|
||||
|
@ -78,29 +78,53 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
// building valid MLIR using MlirHloBuilder.
|
||||
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
|
||||
// all tf2xla kernels.
|
||||
// clang-format off
|
||||
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
|
||||
TypeID::get<TF::AbsOp>(), TypeID::get<TF::AddV2Op>(),
|
||||
TypeID::get<TF::Atan2Op>(), TypeID::get<TF::BatchMatMulV2Op>(),
|
||||
TypeID::get<TF::BiasAddOp>(), TypeID::get<TF::BiasAddGradOp>(),
|
||||
TypeID::get<TF::BitwiseAndOp>(), TypeID::get<TF::BitwiseOrOp>(),
|
||||
TypeID::get<TF::BitwiseXorOp>(), TypeID::get<TF::CastOp>(),
|
||||
TypeID::get<TF::ComplexAbsOp>(), TypeID::get<TF::DivNoNanOp>(),
|
||||
TypeID::get<TF::EqualOp>(), TypeID::get<TF::FloorDivOp>(),
|
||||
TypeID::get<TF::FloorModOp>(), TypeID::get<TF::GreaterOp>(),
|
||||
TypeID::get<TF::GreaterEqualOp>(), TypeID::get<TF::GatherNdOp>(),
|
||||
TypeID::get<TF::InvOp>(), TypeID::get<TF::InvertOp>(),
|
||||
TypeID::get<TF::LeftShiftOp>(), TypeID::get<TF::LessOp>(),
|
||||
TypeID::get<TF::LessEqualOp>(), TypeID::get<TF::LogicalAndOp>(),
|
||||
TypeID::get<TF::LogicalNotOp>(), TypeID::get<TF::LogicalOrOp>(),
|
||||
TypeID::get<TF::LogOp>(), TypeID::get<TF::MatMulOp>(),
|
||||
TypeID::get<TF::MulOp>(), TypeID::get<TF::NegOp>(),
|
||||
TypeID::get<TF::NotEqualOp>(), TypeID::get<TF::PowOp>(),
|
||||
TypeID::get<TF::RealDivOp>(), TypeID::get<TF::RightShiftOp>(),
|
||||
TypeID::get<TF::SinOp>(), TypeID::get<TF::SelectV2Op>(),
|
||||
TypeID::get<TF::SubOp>(), TypeID::get<TF::SquareOp>(),
|
||||
TypeID::get<TF::TransposeOp>(), TypeID::get<TF::TruncateDivOp>(),
|
||||
TypeID::get<TF::TruncateModOp>(), TypeID::get<TF::UnpackOp>(),
|
||||
TypeID::get<TF::XlaDotOp>()};
|
||||
TypeID::get<TF::AbsOp>(),
|
||||
TypeID::get<TF::AddV2Op>(),
|
||||
TypeID::get<TF::Atan2Op>(),
|
||||
TypeID::get<TF::BatchMatMulV2Op>(),
|
||||
TypeID::get<TF::BiasAddOp>(),
|
||||
TypeID::get<TF::BiasAddGradOp>(),
|
||||
TypeID::get<TF::BitwiseAndOp>(),
|
||||
TypeID::get<TF::BitwiseOrOp>(),
|
||||
TypeID::get<TF::BitwiseXorOp>(),
|
||||
TypeID::get<TF::CastOp>(),
|
||||
TypeID::get<TF::ComplexAbsOp>(),
|
||||
TypeID::get<TF::DivNoNanOp>(),
|
||||
TypeID::get<TF::EqualOp>(),
|
||||
TypeID::get<TF::FloorDivOp>(),
|
||||
TypeID::get<TF::FloorModOp>(),
|
||||
TypeID::get<TF::GreaterOp>(),
|
||||
TypeID::get<TF::GreaterEqualOp>(),
|
||||
TypeID::get<TF::GatherNdOp>(),
|
||||
TypeID::get<TF::InvOp>(),
|
||||
TypeID::get<TF::InvertOp>(),
|
||||
TypeID::get<TF::LeftShiftOp>(),
|
||||
TypeID::get<TF::LessOp>(),
|
||||
TypeID::get<TF::LessEqualOp>(),
|
||||
TypeID::get<TF::LogicalAndOp>(),
|
||||
TypeID::get<TF::LogicalNotOp>(),
|
||||
TypeID::get<TF::LogicalOrOp>(),
|
||||
TypeID::get<TF::LogOp>(),
|
||||
TypeID::get<TF::MatMulOp>(),
|
||||
TypeID::get<TF::MulOp>(),
|
||||
TypeID::get<TF::NegOp>(),
|
||||
TypeID::get<TF::NotEqualOp>(),
|
||||
TypeID::get<TF::PowOp>(),
|
||||
TypeID::get<TF::RealDivOp>(),
|
||||
TypeID::get<TF::RightShiftOp>(),
|
||||
TypeID::get<TF::SinOp>(),
|
||||
TypeID::get<TF::SelectV2Op>(),
|
||||
TypeID::get<TF::SubOp>(),
|
||||
TypeID::get<TF::SquareOp>(),
|
||||
TypeID::get<TF::TransposeOp>(),
|
||||
TypeID::get<TF::TruncateDivOp>(),
|
||||
TypeID::get<TF::TruncateModOp>(),
|
||||
TypeID::get<TF::UnpackOp>(),
|
||||
TypeID::get<TF::XlaDotOp>()
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
auto* abstractOp = op->getAbstractOperation();
|
||||
if (!abstractOp) return false;
|
||||
|
@ -223,7 +223,7 @@ Status LhloDialectEmitter::Run() {
|
||||
// The function signature will be composed of:
|
||||
// - one memref for each of the parameters.
|
||||
// - one memref for each other buffer allocation.
|
||||
llvm::SmallVector<NamedAttributeList, 8> args_attrs;
|
||||
llvm::SmallVector<MutableDictionaryAttr, 8> args_attrs;
|
||||
for (const HloInstruction* param : computation->parameter_instructions()) {
|
||||
TF_ASSIGN_OR_RETURN(auto arg_type, ::xla::ConvertShapeToType<MemRefType>(
|
||||
param->shape(), builder_));
|
||||
|
@ -338,12 +338,15 @@ Py_hash_t PyBfloat16_Hash(PyObject* self) {
|
||||
|
||||
// Python type for PyBfloat16 objects.
|
||||
PyTypeObject PyBfloat16_Type = {
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
"bfloat16", // tp_name
|
||||
sizeof(PyBfloat16), // tp_basicsize
|
||||
0, // tp_itemsize
|
||||
nullptr, // tp_dealloc
|
||||
0, // tp_print NOLINT
|
||||
PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16", // tp_name
|
||||
sizeof(PyBfloat16), // tp_basicsize
|
||||
0, // tp_itemsize
|
||||
nullptr, // tp_dealloc
|
||||
#if PY_VERSION_HEX < 0x03080000
|
||||
nullptr, // tp_print
|
||||
#else
|
||||
0, // tp_vectorcall_offset
|
||||
#endif
|
||||
nullptr, // tp_getattr
|
||||
nullptr, // tp_setattr
|
||||
nullptr, // tp_compare / tp_reserved
|
||||
|
@ -20,28 +20,18 @@ from __future__ import print_function
|
||||
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.compiler.xla.python import xla_client
|
||||
from tensorflow.compiler.xla.python import xla_extension as _xla
|
||||
from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client
|
||||
|
||||
|
||||
class TpuBackend(xla_client.Backend):
|
||||
class TpuBackend(object):
|
||||
"""XLA backend implemented using the Tpu driver API."""
|
||||
|
||||
# Cache the backends to prevent double driver initializations.
|
||||
_local_backend = None
|
||||
|
||||
def __init__(self, client):
|
||||
"""Creates a new TpuBackend.
|
||||
|
||||
Args:
|
||||
client: A _tpu_client.TpuClient object.
|
||||
"""
|
||||
super(TpuBackend, self).__init__('tpu')
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
def create(worker=None, force=False):
|
||||
"""Constructs a Cloud TPU backend."""
|
||||
# `force` == True will skip caching any backends (if applicable) and will
|
||||
# always try to create a new client.
|
||||
if worker is None:
|
||||
@ -56,52 +46,11 @@ class TpuBackend(xla_client.Backend):
|
||||
if worker == 'local':
|
||||
worker = 'local://'
|
||||
if force:
|
||||
return TpuBackend(_tpu_client.TpuClient.Get(worker))
|
||||
return _tpu_client.TpuClient.Get(worker)
|
||||
if TpuBackend._local_backend is None:
|
||||
logging.info('Starting the local TPU driver.')
|
||||
TpuBackend._local_backend = TpuBackend(
|
||||
_tpu_client.TpuClient.Get(worker))
|
||||
TpuBackend._local_backend = _tpu_client.TpuClient.Get(worker)
|
||||
return TpuBackend._local_backend
|
||||
else:
|
||||
# We do not cache for non-local backends.
|
||||
return TpuBackend(_tpu_client.TpuClient.Get(worker))
|
||||
|
||||
def device_count(self):
|
||||
return self.client.device_count()
|
||||
|
||||
def local_device_count(self):
|
||||
return self.client.local_device_count()
|
||||
|
||||
def local_devices(self):
|
||||
return self.client.local_devices()
|
||||
|
||||
def devices(self):
|
||||
return self.client.devices()
|
||||
|
||||
def host_id(self):
|
||||
return self.client.host_id()
|
||||
|
||||
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
|
||||
return self.client.buffer_from_pyval(pyval, device)
|
||||
|
||||
def compile(self, c_computation, compile_options=None):
|
||||
compile_options = compile_options or xla_client.CompileOptions()
|
||||
options = _xla.CompileOptions()
|
||||
options.argument_layouts = compile_options.argument_layouts
|
||||
options.parameter_is_tupled_arguments = compile_options.tuple_arguments
|
||||
build_options = options.executable_build_options
|
||||
build_options.num_replicas = compile_options.num_replicas
|
||||
build_options.num_partitions = compile_options.num_partitions
|
||||
if compile_options.result_layout:
|
||||
build_options.result_layout = compile_options.result_layout
|
||||
if compile_options.device_assignment:
|
||||
build_options.device_assignment = compile_options.device_assignment
|
||||
return self.client.compile(c_computation, options)
|
||||
|
||||
def get_default_device_assignment(self, num_replicas, num_partitions=None):
|
||||
if num_partitions is not None:
|
||||
return self.client.get_default_device_assignment(num_replicas,
|
||||
num_partitions)
|
||||
else:
|
||||
# TODO(henrytan): delete this case after all callers can handle 2D output
|
||||
return self.client.get_default_device_assignment(num_replicas)
|
||||
return _tpu_client.TpuClient.Get(worker)
|
||||
|
@ -19,7 +19,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
import inspect
|
||||
@ -47,114 +46,6 @@ ops = _xla.ops
|
||||
profiler = _xla.profiler
|
||||
|
||||
|
||||
class Backend(object, metaclass=abc.ABCMeta):
|
||||
"""Abstract base class for XLA backends."""
|
||||
|
||||
def __init__(self, platform):
|
||||
"""Creates a new Backend.
|
||||
|
||||
Args:
|
||||
platform: A string naming the platform; for example 'gpu'.
|
||||
"""
|
||||
self.platform = platform
|
||||
|
||||
@abc.abstractmethod
|
||||
def device_count(self):
|
||||
"""Returns the number of devices known to the backend."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def local_device_count(self):
|
||||
"""Returns the number of devices local to this host."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def devices(self):
|
||||
"""Returns a list of `device_count()` Device subclasses."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def host_id(self):
|
||||
"""Returns the integer ID of this host."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
|
||||
"""Allocates a fresh buffer and populates it with `pyval`."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def compile(self, computation, compile_options=None):
|
||||
"""Compiles a computation. Returns an executable."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_default_device_assignment(self, num_replicas, num_partitions):
|
||||
"""Returns the default device assignment that `compile` would use.
|
||||
|
||||
If `compile_options.device_assignment` isn't set, `compile` will pick a
|
||||
deterministic device assignment based on the number of replicas and
|
||||
partitions, possibly optimizing for device locality. This method returns
|
||||
that assignment, which is useful for e.g. manually replicating a value
|
||||
before passing it to a compiled executable.
|
||||
|
||||
Args:
|
||||
num_replicas: the number of replicas needed.
|
||||
num_partitions: the number of partitions needed.
|
||||
|
||||
Returns:
|
||||
A list of list of Devices of size `(num_replicas, num_partitions)`.
|
||||
"""
|
||||
|
||||
|
||||
class LocalBackend(Backend):
|
||||
"""XLA backend implemented using the in-process xla::LocalClient API."""
|
||||
|
||||
def __init__(self, platform, client):
|
||||
"""Creates a new LocalBackend.
|
||||
|
||||
Args:
|
||||
platform: A string; the user-visible platform name, e.g. 'gpu'.
|
||||
client: An _xla.PyLocalClient object.
|
||||
"""
|
||||
super(LocalBackend, self).__init__(platform)
|
||||
self.client = client
|
||||
|
||||
def device_count(self):
|
||||
return self.client.device_count()
|
||||
|
||||
def local_device_count(self):
|
||||
return self.client.local_device_count()
|
||||
|
||||
def devices(self):
|
||||
return self.client.devices()
|
||||
|
||||
def local_devices(self):
|
||||
return self.client.local_devices()
|
||||
|
||||
def host_id(self):
|
||||
return self.client.host_id()
|
||||
|
||||
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
|
||||
return self.client.buffer_from_pyval(pyval, device, force_copy)
|
||||
|
||||
def compile(self, c_computation, compile_options=None):
|
||||
compile_options = compile_options or CompileOptions()
|
||||
options = _xla.CompileOptions()
|
||||
options.argument_layouts = compile_options.argument_layouts
|
||||
options.parameter_is_tupled_arguments = compile_options.tuple_arguments
|
||||
build_options = options.executable_build_options
|
||||
build_options.num_replicas = compile_options.num_replicas
|
||||
build_options.num_partitions = compile_options.num_partitions
|
||||
if compile_options.result_layout:
|
||||
build_options.result_layout = compile_options.result_layout
|
||||
if compile_options.device_assignment:
|
||||
build_options.device_assignment = compile_options.device_assignment
|
||||
return self.client.compile(c_computation, options)
|
||||
|
||||
def get_default_device_assignment(self, num_replicas, num_partitions=None):
|
||||
if num_partitions is not None:
|
||||
return self.client.get_default_device_assignment(num_replicas,
|
||||
num_partitions)
|
||||
else:
|
||||
# TODO(skye): delete this case after all callers can handle 2D output
|
||||
return self.client.get_default_device_assignment(num_replicas)
|
||||
|
||||
|
||||
xla_platform_names = {
|
||||
'cpu': 'Host',
|
||||
'gpu': 'CUDA',
|
||||
@ -162,8 +53,7 @@ xla_platform_names = {
|
||||
|
||||
|
||||
def _cpu_backend_factory():
|
||||
client = _xla.get_cpu_client(asynchronous=True)
|
||||
return LocalBackend(platform='cpu', client=client)
|
||||
return _xla.get_cpu_client(asynchronous=True)
|
||||
|
||||
|
||||
def _gpu_backend_factory(distributed_client=None, node_id=0):
|
||||
@ -186,12 +76,11 @@ def _gpu_backend_factory(distributed_client=None, node_id=0):
|
||||
config.memory_fraction = float(memory_fraction)
|
||||
config.preallocate = preallocate not in ('0', 'false', 'False')
|
||||
|
||||
client = _xla.get_nvidia_gpu_client(
|
||||
return _xla.get_nvidia_gpu_client(
|
||||
asynchronous=True,
|
||||
allocator_config=config,
|
||||
distributed_client=distributed_client,
|
||||
node_id=node_id)
|
||||
return LocalBackend(platform='gpu', client=client)
|
||||
|
||||
|
||||
# Backend factories, keyed by user-visible name, in increasing priority order.
|
||||
@ -480,29 +369,7 @@ def computation_count():
|
||||
"""
|
||||
|
||||
Device = _xla.Device
|
||||
|
||||
|
||||
class CompileOptions(object):
|
||||
"""Python object for XLA compile options.
|
||||
|
||||
These options can be passed to the 'compile' step when using a local XLA
|
||||
client.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.executable_build_options = _xla.ExecutableBuildOptions()
|
||||
self.xla_dump_to = None
|
||||
self.dump_hlo_pass_re = None
|
||||
self.dump_hlo_module_re = None
|
||||
self.dump_hlo_as_text = None
|
||||
self.dump_hlo_as_proto = None
|
||||
self.hlo_profile = None
|
||||
self.num_replicas = 1
|
||||
self.num_partitions = 1
|
||||
self.argument_layouts = None
|
||||
self.result_layout = None
|
||||
self.device_assignment = None
|
||||
self.tuple_arguments = False
|
||||
CompileOptions = _xla.CompileOptions
|
||||
|
||||
|
||||
# An Executable is a C++ class that duck types with the following API:
|
||||
|
@ -1942,7 +1942,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
del buffer # Free "buffer" to make sure dlt retains ownership.
|
||||
self.assertEqual(type(dlt).__name__, "PyCapsule")
|
||||
y = xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlt, self.backend.client)
|
||||
dlt, self.backend)
|
||||
np.testing.assert_array_equal(x, y.to_py())
|
||||
|
||||
def testTensorsCanBeConsumedOnceOnly(self):
|
||||
@ -1952,7 +1952,7 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
|
||||
def ConsumeDLPackTensor():
|
||||
_ = xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlt, self.backend.client)
|
||||
dlt, self.backend)
|
||||
|
||||
ConsumeDLPackTensor()
|
||||
self.assertRaisesRegex(
|
||||
|
@ -68,8 +68,8 @@ class HloPassFix : public Pass {
|
||||
VLOG(3) << "changed_this_iteration: " << changed_this_iteration;
|
||||
++iteration_count;
|
||||
if (iteration_count == kLimit) {
|
||||
LOG(WARNING) << "Unexpectedly high number of iterations in HLO passes, "
|
||||
"exiting fixed point loop.";
|
||||
VLOG(1) << "Unexpectedly high number of iterations in HLO passes, "
|
||||
"exiting fixed point loop.";
|
||||
// Return false in case this is fixed point is nested.
|
||||
return false;
|
||||
}
|
||||
|
@ -22,6 +22,76 @@ namespace {
|
||||
// Define a dummy chunk for chunks that will be allocated in the default memory
|
||||
// space and for keeping track of number of asynchronous copies.
|
||||
const HeapSimulator::Chunk kDummyChunk{-1, -1};
|
||||
|
||||
// Returns a heuristic value that captures how much putting this tensor to
|
||||
// the alternate memory would help if the op is memory bound, or otherwise
|
||||
// how far off is the op to memory boundedness. The larger this number, the
|
||||
// higher priority it will be placed in the alternate memory.
|
||||
float GetAlternateMemoryBenefit(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
const HloInstruction& instruction,
|
||||
float elapsed_time_due_to_alternate_mem) {
|
||||
float elapsed_time_due_to_compute =
|
||||
cost_analysis.GetInstructionElapsedDueToCompute(instruction);
|
||||
float elapsed_time_due_to_memory =
|
||||
cost_analysis.GetInstructionElapsedDueToMemory(instruction);
|
||||
if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
|
||||
// Memory bound, return how much alternate memory is better.
|
||||
return elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem;
|
||||
} else {
|
||||
// Compute bound, return how far off are we to memory boundedness.
|
||||
return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a heuristic value of memory boundedness for the given BufferInterval.
|
||||
// The larger this number, the higher priority it will be placed in the
|
||||
// alternate memory.
|
||||
float GetMemoryBoundedness(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
|
||||
const HloInstruction& defining_instruction =
|
||||
*interval.buffer->defining_instruction();
|
||||
float alternate_mem_benefit =
|
||||
GetAlternateMemoryBenefit(cost_analysis, defining_instruction,
|
||||
cost_analysis.GetInstructionElapsedDueToMemory(
|
||||
defining_instruction,
|
||||
/*operand_in_alternate_mem=*/{},
|
||||
/*output_in_alternate_mem=*/true));
|
||||
for (const HloUse& use : interval.buffer->uses()) {
|
||||
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
|
||||
cost_analysis, *use.instruction,
|
||||
cost_analysis.GetInstructionElapsedDueToMemory(*use.instruction,
|
||||
use.operand_number));
|
||||
// If the benefit is positive (memory bound), add it to this buffer's
|
||||
// benefit. If the benefit is negative (compute bound), calculate the
|
||||
// maximum.
|
||||
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
|
||||
alternate_mem_benefit += use_alternate_mem_benefit;
|
||||
} else {
|
||||
alternate_mem_benefit =
|
||||
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
|
||||
}
|
||||
}
|
||||
|
||||
// Get performance slowdown in seconds of prefetching current BufferInterval
|
||||
// causing to other BufferIntervals.
|
||||
float alternate_mem_slowdown =
|
||||
cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size);
|
||||
|
||||
// Scale the slowdown based on the time of this buffer. We would want earlier
|
||||
// buffers have lower slowdown values, because they are less likely to overlap
|
||||
// with other HLOs.
|
||||
// TODO(yuemmawang): We may want a piecewise function, where a lower slowdown
|
||||
// for early HLOs, and full slowdown for mid-to-late HLOs.
|
||||
// TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
|
||||
// more HLOs have higher slowdown, and vice versa.
|
||||
float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime();
|
||||
alternate_mem_slowdown *= scale;
|
||||
|
||||
return alternate_mem_benefit - alternate_mem_slowdown;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
|
||||
@ -255,6 +325,12 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
|
||||
", logical interval elapsed (s) = ", logical_interval_elapsed);
|
||||
}
|
||||
|
||||
absl::optional<float>
|
||||
CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
||||
return GetMemoryBoundedness(cost_analysis_, interval);
|
||||
}
|
||||
|
||||
std::string MemorySpaceAssignment::AllocationValue::ToString() const {
|
||||
std::string out = absl::StrCat("computation = ", computation()->name());
|
||||
absl::StrAppend(&out, "\n position:\n");
|
||||
@ -495,6 +571,86 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
|
||||
return true;
|
||||
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
|
||||
std::string* debug_str) const {
|
||||
// Columns in buffer information:
|
||||
// buffer_id: int. This value can be used to match the allocation in
|
||||
// allocation information.
|
||||
// buffer_name: string.
|
||||
// alt_mem_benefit: float. Roughly corresponds to how much the cost analysis
|
||||
// thought it would be beneficial to put this in the alternate memory. The
|
||||
// higher the value, the more it is memory bound.
|
||||
// size: int. In bytes.
|
||||
// definition_time: int. Logical time this value was defined in the schedule.
|
||||
// use_times: string. This is a semicolon-separated list of integers for all
|
||||
// the use times.
|
||||
if (debug_str->empty()) {
|
||||
// Append the column names.
|
||||
absl::StrAppend(debug_str,
|
||||
"buffer_id,buffer_name,alt_mem_benefit,size,definition_"
|
||||
"time,use_times\n");
|
||||
}
|
||||
const HloBuffer& buffer =
|
||||
alias_analysis_.GetBufferContainingValue(*interval.buffer);
|
||||
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
|
||||
int64 definition_time =
|
||||
instruction_schedule.at(interval.buffer->defining_position().instruction);
|
||||
std::set<int64> use_times;
|
||||
for (const HloValue* value : buffer.values()) {
|
||||
for (const HloUse& use : value->uses()) {
|
||||
use_times.insert(instruction_schedule.at(use.instruction));
|
||||
}
|
||||
}
|
||||
|
||||
absl::StrAppend(debug_str, buffer.id(), ",");
|
||||
absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\",");
|
||||
auto alternate_memory_benefit =
|
||||
options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit(
|
||||
interval);
|
||||
absl::StrAppend(
|
||||
debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ",");
|
||||
absl::StrAppend(debug_str, interval.size, ",");
|
||||
absl::StrAppend(debug_str, definition_time, ",");
|
||||
absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\"");
|
||||
absl::StrAppend(debug_str, "\n");
|
||||
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
|
||||
const MemorySpaceAssignment::Allocation& allocation,
|
||||
std::string* debug_str) const {
|
||||
// Columns in allocation information:
|
||||
// buffer_id: int. This value can be used the match with buffer info.
|
||||
// size: int. In bytes.
|
||||
// offset: int. In bytes.
|
||||
// start_time: int. Logical start time of the allocation.
|
||||
// end_time: int. Logical end time of the allocation.
|
||||
if (debug_str->empty()) {
|
||||
// Append the column names.
|
||||
absl::StrAppend(debug_str, "buffer_id,size,offset,start_time,end_time\n");
|
||||
}
|
||||
if (allocation.memory_space() == MemorySpace::kAlternate) {
|
||||
const HloBuffer& buffer =
|
||||
alias_analysis_.GetBufferContainingValue(*interval.buffer);
|
||||
absl::StrAppend(debug_str, buffer.id(), ",");
|
||||
absl::StrAppend(debug_str, interval.size, ",");
|
||||
absl::StrAppend(debug_str, allocation.chunk().offset, ",");
|
||||
absl::StrAppend(debug_str, allocation.start_time(), ",");
|
||||
absl::StrAppend(debug_str, allocation.end_time(), "\n");
|
||||
}
|
||||
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::DumpIfEnabled(
|
||||
absl::string_view buffer_info_str,
|
||||
absl::string_view allocation_info_str) const {
|
||||
if (!options_.dump_fn) {
|
||||
return;
|
||||
}
|
||||
options_.dump_fn("bufferinfo", buffer_info_str);
|
||||
options_.dump_fn("allocinfo", allocation_info_str);
|
||||
}
|
||||
|
||||
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
std::vector<BufferInterval> sorted_buffer_intervals =
|
||||
GetSortedBufferIntervals();
|
||||
@ -514,6 +670,9 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
}
|
||||
}
|
||||
|
||||
std::string buffer_info_str;
|
||||
std::string allocation_info_str;
|
||||
|
||||
for (auto& interval : sorted_buffer_intervals) {
|
||||
if (!interval.need_allocation) {
|
||||
continue;
|
||||
@ -616,6 +775,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
}
|
||||
}
|
||||
|
||||
AppendBufferInfoDebugString(interval, &buffer_info_str);
|
||||
|
||||
// Data structure to contain the preferred offset for a given computation.
|
||||
// We ensure that the same offset will be allocated outside the while loop
|
||||
// as well as inside the while loop.
|
||||
@ -743,6 +904,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
if (allocation_success) {
|
||||
for (AllocationValue& allocation_value : allocation_values) {
|
||||
for (auto& allocation : *allocation_value.allocation_sequence()) {
|
||||
AppendAllocationInfoDebugString(interval, *allocation,
|
||||
&allocation_info_str);
|
||||
allocations_->push_back(std::move(allocation));
|
||||
}
|
||||
}
|
||||
@ -752,6 +915,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
pending_async_copies_.clear();
|
||||
}
|
||||
|
||||
VLOG(3) << "Debug buffer info: ";
|
||||
VLOG(3) << buffer_info_str;
|
||||
VLOG(3) << "Debug allocation info: ";
|
||||
VLOG(3) << allocation_info_str;
|
||||
DumpIfEnabled(buffer_info_str, allocation_info_str);
|
||||
|
||||
return result_;
|
||||
}
|
||||
|
||||
@ -1544,70 +1713,8 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
|
||||
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
||||
const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
|
||||
return [&](const BufferInterval& x, const BufferInterval& y) {
|
||||
// Returns a heuristic value that captures how much putting this tensor to
|
||||
// the alternate memory would help if the op is memory bound, or otherwise
|
||||
// how far off is the op to memory boundedness. The larger this number, the
|
||||
// higher priority it will be placed in the alternate memory.
|
||||
auto get_alternate_mem_benefit =
|
||||
[&](const HloInstruction& instruction,
|
||||
float elapsed_time_due_to_alternate_mem) {
|
||||
float elapsed_time_due_to_compute =
|
||||
cost_analysis.GetInstructionElapsedDueToCompute(instruction);
|
||||
float elapsed_time_due_to_memory =
|
||||
cost_analysis.GetInstructionElapsedDueToMemory(instruction);
|
||||
if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
|
||||
// Memory bound, return how much alternate memory is better.
|
||||
return elapsed_time_due_to_memory -
|
||||
elapsed_time_due_to_alternate_mem;
|
||||
} else {
|
||||
// Compute bound, return how far off are we to memory boundedness.
|
||||
return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
|
||||
}
|
||||
};
|
||||
|
||||
auto get_memory_boundedness = [&](const BufferInterval& interval) {
|
||||
const HloInstruction& defining_instruction =
|
||||
*interval.buffer->defining_instruction();
|
||||
float alternate_mem_benefit = get_alternate_mem_benefit(
|
||||
defining_instruction, cost_analysis.GetInstructionElapsedDueToMemory(
|
||||
defining_instruction,
|
||||
/*operand_in_alternate_mem=*/{},
|
||||
/*output_in_alternate_mem=*/true));
|
||||
for (const HloUse& use : interval.buffer->uses()) {
|
||||
float use_alternate_mem_benefit = get_alternate_mem_benefit(
|
||||
*use.instruction, cost_analysis.GetInstructionElapsedDueToMemory(
|
||||
*use.instruction, use.operand_number));
|
||||
// If the benefit is positive (memory bound), add it to this buffer's
|
||||
// benefit. If the benefit is negative (compute bound), calculate the
|
||||
// maximum.
|
||||
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
|
||||
alternate_mem_benefit += use_alternate_mem_benefit;
|
||||
} else {
|
||||
alternate_mem_benefit =
|
||||
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
|
||||
}
|
||||
}
|
||||
|
||||
// Get performance slowdown in seconds of prefetching current
|
||||
// BufferInterval causing to other BufferIntervals.
|
||||
float alternate_mem_slowdown =
|
||||
cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size);
|
||||
|
||||
// Scale the slowdown based on the time of this buffer. We would want
|
||||
// earlier buffers have lower slowdown values, because they are less
|
||||
// likely to overlap with other HLOs.
|
||||
// TODO (yuemmawang) We may want a piecewise function, where a lower
|
||||
// slowdown for early HLOs, and full slowdown for mid-to-late HLOs.
|
||||
// TODO (yuemmawang) Further in a smarter way, we want buffers overlapped
|
||||
// with more HLOs have higher slowdown, and vice versa.
|
||||
float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime();
|
||||
alternate_mem_slowdown *= scale;
|
||||
|
||||
return alternate_mem_benefit - alternate_mem_slowdown;
|
||||
};
|
||||
|
||||
float x_memory_boundedness = get_memory_boundedness(x);
|
||||
float y_memory_boundedness = get_memory_boundedness(y);
|
||||
float x_memory_boundedness = GetMemoryBoundedness(cost_analysis, x);
|
||||
float y_memory_boundedness = GetMemoryBoundedness(cost_analysis, y);
|
||||
if (x_memory_boundedness != y_memory_boundedness) {
|
||||
return x_memory_boundedness > y_memory_boundedness;
|
||||
}
|
||||
|
@ -63,9 +63,15 @@ class PresetAssignments {
|
||||
return assignment_info_;
|
||||
}
|
||||
|
||||
// Get debugging information.
|
||||
std::string buffer_info_str() const { return buffer_info_str_; }
|
||||
std::string allocation_info_str() const { return allocation_info_str_; }
|
||||
|
||||
private:
|
||||
std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
|
||||
std::vector<std::pair<int64, AssignmentInformation>> assignment_info_;
|
||||
std::string buffer_info_str_;
|
||||
std::string allocation_info_str_;
|
||||
};
|
||||
|
||||
// A wrapper class around HloCostAnalysis with additional knowledge about the
|
||||
@ -165,6 +171,14 @@ class PrefetchIntervalPicker {
|
||||
virtual std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
|
||||
int64 end_time) const = 0;
|
||||
|
||||
// Prefetch interval pickers may return a value corresponding to the benefit
|
||||
// of placing the BufferInterval in the alternate memory. The larger value,
|
||||
// the more beneficial.
|
||||
virtual absl::optional<float> BufferIntervalAlternateMemoryBenefit(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
protected:
|
||||
const absl::flat_hash_map<const HloInstruction*, int64>*
|
||||
instruction_schedule_ = nullptr;
|
||||
@ -239,6 +253,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
||||
std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
|
||||
int64 end_time) const override;
|
||||
|
||||
absl::optional<float> BufferIntervalAlternateMemoryBenefit(
|
||||
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval)
|
||||
const override;
|
||||
|
||||
private:
|
||||
// Returns the elapsed time in seconds between the logical interval that
|
||||
// corresponds to the instruction schedule.
|
||||
@ -317,6 +335,11 @@ class MemorySpaceAssignment {
|
||||
// buffers.
|
||||
bool verify = false;
|
||||
|
||||
// If not nullptr, this function is called to dump debugging information.
|
||||
// The first argument is appended to the file name and the second argument
|
||||
// is the contents of the file.
|
||||
std::function<void(absl::string_view, absl::string_view)> dump_fn = nullptr;
|
||||
|
||||
// Enable prefetching buffers into preferred memory across program
|
||||
// boundaries
|
||||
bool enable_cross_program_prefetch = true;
|
||||
@ -899,6 +922,17 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
// buffers from the interval trees.
|
||||
void UncommitPendingChunks();
|
||||
|
||||
// Append buffer and allocation infos for debugging and dump it into a file,
|
||||
// if enabled.
|
||||
void AppendBufferInfoDebugString(const BufferInterval& interval,
|
||||
std::string* debug_str) const;
|
||||
void AppendAllocationInfoDebugString(
|
||||
const BufferInterval& interval,
|
||||
const MemorySpaceAssignment::Allocation& allocation,
|
||||
std::string* debug_str) const;
|
||||
void DumpIfEnabled(absl::string_view buffer_info_str,
|
||||
absl::string_view allocation_info_str) const;
|
||||
|
||||
// Returns the available heap size in the alternate memory.
|
||||
int64 available_heap_size() const {
|
||||
return options_.max_size_in_bytes - reserved_in_bytes_;
|
||||
|
@ -32,15 +32,20 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
StatusOr<XlaOp> GetPhiloxStateOp(XlaOp input_state) {
|
||||
TF_ASSIGN_OR_RETURN(const Shape* shape,
|
||||
input_state.builder()->GetShapePtr(input_state));
|
||||
if (shape->dimensions(0) >= 3) {
|
||||
XlaOp GetPhiloxStateOp(XlaOp input_state, const Shape& state_shape) {
|
||||
if (state_shape.dimensions(0) >= 3) {
|
||||
return Slice(input_state, {1}, {3}, {1});
|
||||
}
|
||||
return Rev(input_state, {0});
|
||||
}
|
||||
|
||||
XlaOp GetPhiloxOutputStateOp(XlaOp output_state, const Shape& state_shape) {
|
||||
if (state_shape.dimensions(0) < 3) {
|
||||
output_state = Slice(output_state, {0}, {1}, {1});
|
||||
}
|
||||
return output_state;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RngBitGeneratorExpander::InstructionMatchesPattern(
|
||||
@ -60,25 +65,22 @@ StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation(
|
||||
XlaBuilder builder("rng");
|
||||
XlaOp state_param = Parameter(&builder, 0, state_shape, "state");
|
||||
XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {});
|
||||
XlaOp state_op;
|
||||
|
||||
BitGeneratorTy generator = nullptr;
|
||||
RngOutput output;
|
||||
switch (algorithm) {
|
||||
case RandomAlgorithm::RNG_THREE_FRY:
|
||||
generator = ThreeFryBitGenerator;
|
||||
state_op = Slice(state_param, {1}, {2}, {1});
|
||||
output = ThreeFryBitGenerator(key_op, Slice(state_param, {1}, {2}, {1}),
|
||||
data_shape);
|
||||
break;
|
||||
case RandomAlgorithm::RNG_PHILOX: {
|
||||
generator = PhiloxBitGenerator;
|
||||
TF_ASSIGN_OR_RETURN(state_op, GetPhiloxStateOp(state_param));
|
||||
case RandomAlgorithm::RNG_PHILOX:
|
||||
output = PhiloxBitGenerator(
|
||||
key_op, GetPhiloxStateOp(state_param, state_shape), data_shape);
|
||||
output.state = GetPhiloxOutputStateOp(output.state, state_shape);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Unimplemented("Unsupported random algorthm: %s",
|
||||
RandomAlgorithm_Name(algorithm));
|
||||
}
|
||||
|
||||
RngOutput output = generator(key_op, state_op, data_shape);
|
||||
XlaOp final_state =
|
||||
ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
|
||||
Tuple(&builder, {final_state, output.value});
|
||||
|
@ -27,7 +27,7 @@ func @main(%arg0 : tensor<10xf32>, %arg1 : tensor<10xf32>) -> tensor<10x10xf32>
|
||||
|
||||
@tf.function
|
||||
def foo(x, y):
|
||||
return = mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
|
||||
return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
|
||||
|
||||
graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
|
||||
```
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "ShuffleAndRepeatDatasetV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -165,6 +165,12 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
|
||||
void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame,
|
||||
DoneCallback done) override;
|
||||
|
||||
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) override;
|
||||
|
||||
Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* frame) override;
|
||||
|
||||
Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
|
||||
OpKernel** kernel) override;
|
||||
|
||||
@ -235,6 +241,17 @@ void FunctionLibraryRuntimeOverlay::Run(const Options& opts, Handle handle,
|
||||
base_flr_->Run(opts, handle, call_frame, std::move(done));
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) {
|
||||
return base_flr_->RunSync(std::move(opts), handle, args, rets);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) {
|
||||
return base_flr_->RunSync(std::move(opts), handle, call_frame);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeOverlay::CreateKernel(
|
||||
const std::shared_ptr<const NodeProperties>&, OpKernel**) {
|
||||
// We don't have access to base_lib_def_ in base function library runtime (aka
|
||||
@ -331,6 +348,10 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
std::vector<Tensor>* rets, DoneCallback done) override;
|
||||
void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
|
||||
DoneCallback done) override;
|
||||
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) override;
|
||||
Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) override;
|
||||
|
||||
bool IsStateful(const string& function) const override;
|
||||
|
||||
@ -424,6 +445,10 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
|
||||
Item* item, DoneCallback done);
|
||||
|
||||
Status PrepareRunSync(
|
||||
Handle handle, Options* run_opts, Item** out_item,
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous);
|
||||
|
||||
void ExecutorArgsFromOptions(const FunctionLibraryRuntime::Options& run_opts,
|
||||
CallFrameInterface* frame,
|
||||
Executor::Args* exec_args);
|
||||
@ -1187,6 +1212,79 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||
item->exec->RunAsync(exec_args, std::move(done));
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::PrepareRunSync(
|
||||
Handle handle, Options* run_opts, Item** out_item,
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous>* out_rendezvous) {
|
||||
if (run_opts->cancellation_manager &&
|
||||
run_opts->cancellation_manager->IsCancelled()) {
|
||||
return errors::Cancelled("");
|
||||
}
|
||||
|
||||
if (run_opts->remote_execution) {
|
||||
// NOTE(mrry): This bit is only set for a local function when `parent_`
|
||||
// calls back into this class, and the current implementation of
|
||||
// `ProcessFunctionLibraryRuntime` currently always uses the asynchronous
|
||||
// Run() method.
|
||||
return errors::Unimplemented("Remote calling with RunSync()");
|
||||
}
|
||||
|
||||
if (run_opts->create_rendezvous) {
|
||||
*out_rendezvous =
|
||||
absl::make_unique<PrivateIntraProcessRendezvous>(device_mgr_);
|
||||
run_opts->rendezvous = out_rendezvous->get();
|
||||
run_opts->create_rendezvous = false;
|
||||
}
|
||||
|
||||
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
|
||||
if (local_handle == kInvalidLocalHandle) {
|
||||
*out_item = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(GetOrCreateItem(local_handle, out_item));
|
||||
|
||||
if (run_opts->runner == nullptr) {
|
||||
run_opts->runner = &default_runner_;
|
||||
}
|
||||
DCHECK(run_opts->runner != nullptr);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) {
|
||||
Item* item = nullptr;
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
|
||||
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
|
||||
if (item == nullptr) {
|
||||
return parent_->RunSync(opts, handle, args, rets);
|
||||
}
|
||||
|
||||
Executor::Args exec_args;
|
||||
const FunctionBody* fbody = GetFunctionBody(handle);
|
||||
FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
|
||||
TF_RETURN_IF_ERROR(frame.SetArgs(args));
|
||||
ExecutorArgsFromOptions(opts, &frame, &exec_args);
|
||||
|
||||
TF_RETURN_IF_ERROR(item->exec->Run(exec_args));
|
||||
return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) {
|
||||
Item* item = nullptr;
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
|
||||
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
|
||||
if (item == nullptr) {
|
||||
return parent_->RunSync(opts, handle, call_frame);
|
||||
}
|
||||
|
||||
Executor::Args exec_args;
|
||||
ExecutorArgsFromOptions(opts, call_frame, &exec_args);
|
||||
return item->exec->Run(exec_args);
|
||||
}
|
||||
|
||||
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
|
||||
const OpDef* op_def;
|
||||
const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
|
||||
|
@ -1488,6 +1488,33 @@ void ProcessFunctionLibraryRuntime::Run(
|
||||
});
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::RunSync(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) const {
|
||||
Notification n;
|
||||
Status s;
|
||||
Run(opts, handle, args, rets, [&n, &s](const Status& status) {
|
||||
s.Update(status);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return s;
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::RunSync(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
|
||||
Notification n;
|
||||
Status s;
|
||||
Run(opts, handle, frame, [&n, &s](const Status& status) {
|
||||
s.Update(status);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return s;
|
||||
}
|
||||
|
||||
void ProcessFunctionLibraryRuntime::Run(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
|
||||
|
@ -194,6 +194,13 @@ class ProcessFunctionLibraryRuntime {
|
||||
const FunctionArgsInterface& args, std::vector<Tensor>* rets,
|
||||
FunctionLibraryRuntime::DoneCallback done) const;
|
||||
|
||||
Status RunSync(const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle,
|
||||
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const;
|
||||
Status RunSync(const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle,
|
||||
CallFrameInterface* frame) const;
|
||||
|
||||
const DeviceMgr* device_mgr() { return device_mgr_; }
|
||||
|
||||
const std::shared_ptr<DeviceSet> device_set() {
|
||||
|
@ -290,12 +290,34 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "data_service",
|
||||
srcs = ["data_service.cc"],
|
||||
hdrs = [
|
||||
"data_service.h",
|
||||
],
|
||||
deps = [
|
||||
":credentials_factory",
|
||||
":grpc_util",
|
||||
":master_cc_grpc_proto",
|
||||
":master_proto_cc",
|
||||
":worker_cc_grpc_proto",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "data_service_test",
|
||||
srcs = ["data_service_test.cc"],
|
||||
tags = ["no_windows"],
|
||||
deps = [
|
||||
":compression_utils",
|
||||
":data_service",
|
||||
":grpc_master_impl",
|
||||
":grpc_util",
|
||||
":grpc_worker_impl",
|
||||
|
140
tensorflow/core/data/service/data_service.cc
Normal file
140
tensorflow/core/data/service/data_service.cc
Normal file
@ -0,0 +1,140 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/data/service/data_service.h"
|
||||
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
Status DataServiceMasterClient::CreateJob(int64 dataset_id,
|
||||
ProcessingMode processing_mode,
|
||||
int64* job_id) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
CreateJobRequest req;
|
||||
req.set_dataset_id(dataset_id);
|
||||
req.set_processing_mode(ProcessingModeDef(processing_mode));
|
||||
CreateJobResponse resp;
|
||||
grpc::ClientContext client_ctx;
|
||||
grpc::Status status = stub_->CreateJob(&client_ctx, req, &resp);
|
||||
if (!status.ok()) {
|
||||
return grpc_util::WrapError(
|
||||
absl::StrCat("Failed to create job for dataset with id ", dataset_id),
|
||||
status);
|
||||
}
|
||||
*job_id = resp.job_id();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
|
||||
int64* dataset_id) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
GetOrRegisterDatasetRequest req;
|
||||
*req.mutable_dataset()->mutable_graph() = dataset;
|
||||
GetOrRegisterDatasetResponse resp;
|
||||
grpc::ClientContext client_ctx;
|
||||
grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
|
||||
if (!status.ok()) {
|
||||
return grpc_util::WrapError("Failed to register dataset", status);
|
||||
}
|
||||
*dataset_id = resp.dataset_id();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceMasterClient::GetTasks(int64 job_id,
|
||||
std::vector<TaskInfo>* tasks,
|
||||
bool* job_finished) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
GetTasksRequest req;
|
||||
req.set_job_id(job_id);
|
||||
GetTasksResponse resp;
|
||||
grpc_impl::ClientContext ctx;
|
||||
grpc::Status s = stub_->GetTasks(&ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get tasks", s);
|
||||
}
|
||||
tasks->clear();
|
||||
for (auto& task : resp.task_info()) {
|
||||
tasks->push_back(task);
|
||||
}
|
||||
*job_finished = resp.job_finished();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceMasterClient::EnsureInitialized() {
|
||||
std::shared_ptr<grpc::ChannelCredentials> credentials;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
||||
auto channel = grpc::CreateChannel(address_, credentials);
|
||||
stub_ = MasterService::NewStub(channel);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceWorkerClient::GetElement(int64 task_id,
|
||||
CompressedElement* element,
|
||||
bool* end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_id);
|
||||
GetElementResponse resp;
|
||||
grpc_impl::ClientContext ctx;
|
||||
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get element", s);
|
||||
}
|
||||
*end_of_sequence = resp.end_of_sequence();
|
||||
if (!*end_of_sequence) {
|
||||
*element = std::move(*resp.mutable_compressed_element());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DataServiceWorkerClient::EnsureInitialized() {
|
||||
std::shared_ptr<grpc::ChannelCredentials> credentials;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(-1);
|
||||
auto channel = grpc::CreateCustomChannel(address_, credentials, args);
|
||||
stub_ = WorkerService::NewStub(channel);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateDataServiceMasterClient(
|
||||
absl::string_view address, absl::string_view protocol,
|
||||
std::unique_ptr<DataServiceMasterClient>* out) {
|
||||
auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
|
||||
TF_RETURN_IF_ERROR(client->Initialize());
|
||||
*out = std::move(client);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateDataServiceWorkerClient(
|
||||
absl::string_view address, absl::string_view protocol,
|
||||
std::unique_ptr<DataServiceWorkerClient>* out) {
|
||||
auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
|
||||
TF_RETURN_IF_ERROR(client->Initialize());
|
||||
*out = std::move(client);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
123
tensorflow/core/data/service/data_service.h
Normal file
123
tensorflow/core/data/service/data_service.h
Normal file
@ -0,0 +1,123 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
||||
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
||||
|
||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// Modes for how a tf.data service job should process a dataset.
|
||||
enum class ProcessingMode : int64 {
|
||||
// Each tf.data worker processes an entire epoch. If a dataset contains 2
|
||||
// elements and there are 3 workers, the job will produce 6 elements.
|
||||
PARALLEL_EPOCHS = 0,
|
||||
// Processing of a single epoch is distributed across all tf.data workers.
|
||||
ONE_EPOCH = 1,
|
||||
};
|
||||
|
||||
// Base class for data service clients. Data service clients are
|
||||
// thread-compatible, requiring external synchronization when used from multiple
|
||||
// threads.
|
||||
class DataServiceClientBase {
|
||||
public:
|
||||
DataServiceClientBase(absl::string_view address, absl::string_view protocol)
|
||||
: address_(address), protocol_(protocol) {}
|
||||
|
||||
virtual ~DataServiceClientBase() = default;
|
||||
// Not copyable or movable.
|
||||
DataServiceClientBase(const DataServiceClientBase&) = delete;
|
||||
DataServiceClientBase& operator=(const DataServiceClientBase&) = delete;
|
||||
|
||||
// Initializes the client. Calling `Initialize()` is not required since the
|
||||
// first RPC will perform any necessary initialization. However, it can be
|
||||
// useful to call `Initialize()` proactively so that any errors that happen
|
||||
// during initialization can be surfaced earlier.
|
||||
Status Initialize() { return EnsureInitialized(); }
|
||||
|
||||
protected:
|
||||
// Initializes the client if it isn't already initialized.
|
||||
virtual Status EnsureInitialized() = 0;
|
||||
|
||||
const std::string address_;
|
||||
const std::string protocol_;
|
||||
};
|
||||
|
||||
// Client for communicating with the tf.data service master.
|
||||
class DataServiceMasterClient : public DataServiceClientBase {
|
||||
public:
|
||||
DataServiceMasterClient(absl::string_view address, absl::string_view protocol)
|
||||
: DataServiceClientBase(address, protocol) {}
|
||||
|
||||
// Registers a dataset with the tf.data service, and stores the generated
|
||||
// dataset id in `*dataset_id`.
|
||||
Status RegisterDataset(GraphDef dataset, int64* dataset_id);
|
||||
|
||||
// Creates a new tf.data service job for the specified dataset. The id for the
|
||||
// created job will be stored in `*job_id`.
|
||||
Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
|
||||
int64* job_id);
|
||||
|
||||
// Queries the master for the tasks associated with the specified job.
|
||||
// The tasks will be stored in *tasks, and whether the job is finished will
|
||||
// be stored in `*job_finished`.
|
||||
Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
|
||||
bool* job_finished);
|
||||
|
||||
protected:
|
||||
Status EnsureInitialized() override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<MasterService::Stub> stub_;
|
||||
};
|
||||
|
||||
// Client for communicating with the tf.data service worker.
|
||||
class DataServiceWorkerClient : public DataServiceClientBase {
|
||||
public:
|
||||
DataServiceWorkerClient(absl::string_view address, absl::string_view protocol)
|
||||
: DataServiceClientBase(address, protocol) {}
|
||||
|
||||
// Fetches the next element for the specified task_id. The element's
|
||||
// compressed tensors will be stored in *element. If no element is available,
|
||||
// `*end_of_sequence` will be `true`, and `element` will be left unchanged.
|
||||
Status GetElement(int64 task_id, CompressedElement* element,
|
||||
bool* end_of_sequence);
|
||||
|
||||
protected:
|
||||
Status EnsureInitialized() override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<WorkerService::Stub> stub_;
|
||||
};
|
||||
|
||||
// Creates and initializes a new tf.data service master client.
|
||||
Status CreateDataServiceMasterClient(
|
||||
absl::string_view address, absl::string_view protocol,
|
||||
std::unique_ptr<DataServiceMasterClient>* out);
|
||||
|
||||
// Creates and initializes a new tf.data service worker client.
|
||||
Status CreateDataServiceWorkerClient(
|
||||
absl::string_view address, absl::string_view protocol,
|
||||
std::unique_ptr<DataServiceWorkerClient>* out);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/data/service/data_service.h"
|
||||
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
@ -34,97 +36,33 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
namespace {
|
||||
Status RegisterDataset(MasterService::Stub* master_stub,
|
||||
const GraphDef& dataset_graph, int64* dataset_id) {
|
||||
grpc_impl::ClientContext ctx;
|
||||
GetOrRegisterDatasetRequest req;
|
||||
*req.mutable_dataset()->mutable_graph() = dataset_graph;
|
||||
GetOrRegisterDatasetResponse resp;
|
||||
grpc::Status s = master_stub->GetOrRegisterDataset(&ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to register dataset", s);
|
||||
}
|
||||
*dataset_id = resp.dataset_id();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateJob(MasterService::Stub* master_stub, int64 dataset_id,
|
||||
int64* job_id) {
|
||||
grpc_impl::ClientContext ctx;
|
||||
CreateJobRequest req;
|
||||
req.set_dataset_id(dataset_id);
|
||||
CreateJobResponse resp;
|
||||
grpc::Status s = master_stub->CreateJob(&ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to begin epoch", s);
|
||||
}
|
||||
*job_id = resp.job_id();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetTasks(MasterService::Stub* master_stub, int64 job_id,
|
||||
std::vector<TaskInfo>* tasks) {
|
||||
grpc_impl::ClientContext ctx;
|
||||
GetTasksRequest req;
|
||||
req.set_job_id(job_id);
|
||||
GetTasksResponse resp;
|
||||
grpc::Status s = master_stub->GetTasks(&ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get tasks", s);
|
||||
}
|
||||
tasks->clear();
|
||||
for (auto& task : resp.task_info()) {
|
||||
tasks->push_back(task);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetElement(WorkerService::Stub* worker_stub, int64 task_id,
|
||||
std::vector<Tensor>* element, bool* end_of_sequence) {
|
||||
grpc_impl::ClientContext ctx;
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_id);
|
||||
GetElementResponse resp;
|
||||
grpc::Status s = worker_stub->GetElement(&ctx, req, &resp);
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get element", s);
|
||||
}
|
||||
*end_of_sequence = resp.end_of_sequence();
|
||||
if (!*end_of_sequence) {
|
||||
const CompressedElement& compressed = resp.compressed_element();
|
||||
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, element));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
constexpr const char kProtocol[] = "grpc+local";
|
||||
|
||||
Status CheckWorkerOutput(const std::string& worker_address, int64 task_id,
|
||||
std::vector<std::vector<Tensor>> expected_output) {
|
||||
auto worker_channel = grpc::CreateChannel(
|
||||
worker_address, grpc::experimental::LocalCredentials(LOCAL_TCP));
|
||||
std::unique_ptr<WorkerService::Stub> worker_stub =
|
||||
WorkerService::NewStub(worker_channel);
|
||||
DataServiceWorkerClient worker(worker_address, kProtocol);
|
||||
for (std::vector<Tensor>& expected : expected_output) {
|
||||
bool end_of_sequence;
|
||||
std::vector<Tensor> element;
|
||||
CompressedElement compressed;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetElement(worker_stub.get(), task_id, &element, &end_of_sequence));
|
||||
worker.GetElement(task_id, &compressed, &end_of_sequence));
|
||||
if (end_of_sequence) {
|
||||
return errors::Internal("Reached end of sequence too early.");
|
||||
}
|
||||
std::vector<Tensor> element;
|
||||
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element));
|
||||
TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected,
|
||||
/*compare_order=*/true));
|
||||
}
|
||||
// Call GetElement a couple more times to verify tha end_of_sequence keeps
|
||||
// returning true.
|
||||
bool end_of_sequence;
|
||||
std::vector<Tensor> element;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetElement(worker_stub.get(), task_id, &element, &end_of_sequence));
|
||||
CompressedElement compressed;
|
||||
TF_RETURN_IF_ERROR(worker.GetElement(task_id, &compressed, &end_of_sequence));
|
||||
if (!end_of_sequence) {
|
||||
return errors::Internal("Expected end_of_sequence to be true");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetElement(worker_stub.get(), task_id, &element, &end_of_sequence));
|
||||
TF_RETURN_IF_ERROR(worker.GetElement(task_id, &compressed, &end_of_sequence));
|
||||
if (!end_of_sequence) {
|
||||
return errors::Internal("Expected end_of_sequence to be true");
|
||||
}
|
||||
@ -138,22 +76,21 @@ TEST(DataService, IterateDatasetOneWorker) {
|
||||
TF_ASSERT_OK(cluster.Initialize());
|
||||
test_util::GraphDefTestCase test_case;
|
||||
TF_ASSERT_OK(test_util::map_test_case(&test_case));
|
||||
auto master_channel = grpc::CreateChannel(
|
||||
cluster.MasterAddress(), grpc::experimental::LocalCredentials(LOCAL_TCP));
|
||||
std::unique_ptr<MasterService::Stub> master_stub =
|
||||
MasterService::NewStub(master_channel);
|
||||
DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
|
||||
|
||||
int64 dataset_id;
|
||||
TF_ASSERT_OK(
|
||||
RegisterDataset(master_stub.get(), test_case.graph_def, &dataset_id));
|
||||
TF_ASSERT_OK(master.RegisterDataset(test_case.graph_def, &dataset_id));
|
||||
int64 job_id;
|
||||
TF_ASSERT_OK(CreateJob(master_stub.get(), dataset_id, &job_id));
|
||||
TF_ASSERT_OK(
|
||||
master.CreateJob(dataset_id, ProcessingMode::PARALLEL_EPOCHS, &job_id));
|
||||
std::vector<TaskInfo> tasks;
|
||||
TF_ASSERT_OK(GetTasks(master_stub.get(), job_id, &tasks));
|
||||
bool job_finished;
|
||||
TF_ASSERT_OK(master.GetTasks(job_id, &tasks, &job_finished));
|
||||
ASSERT_EQ(tasks.size(), 1);
|
||||
ASSERT_EQ(tasks[0].worker_address(), cluster.WorkerAddress(0));
|
||||
EXPECT_EQ(tasks[0].worker_address(), cluster.WorkerAddress(0));
|
||||
EXPECT_FALSE(job_finished);
|
||||
|
||||
TF_ASSERT_OK(CheckWorkerOutput(tasks[0].worker_address(), tasks[0].id(),
|
||||
TF_EXPECT_OK(CheckWorkerOutput(tasks[0].worker_address(), tasks[0].id(),
|
||||
test_case.output));
|
||||
}
|
||||
|
||||
@ -162,23 +99,22 @@ TEST(DataService, IterateDatasetTwoWorkers) {
|
||||
TF_ASSERT_OK(cluster.Initialize());
|
||||
test_util::GraphDefTestCase test_case;
|
||||
TF_ASSERT_OK(test_util::map_test_case(&test_case));
|
||||
auto master_channel = grpc::CreateChannel(
|
||||
cluster.MasterAddress(), grpc::experimental::LocalCredentials(LOCAL_TCP));
|
||||
std::unique_ptr<MasterService::Stub> master_stub =
|
||||
MasterService::NewStub(master_channel);
|
||||
DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
|
||||
|
||||
int64 dataset_id;
|
||||
TF_ASSERT_OK(
|
||||
RegisterDataset(master_stub.get(), test_case.graph_def, &dataset_id));
|
||||
TF_ASSERT_OK(master.RegisterDataset(test_case.graph_def, &dataset_id));
|
||||
int64 job_id;
|
||||
TF_ASSERT_OK(CreateJob(master_stub.get(), dataset_id, &job_id));
|
||||
TF_ASSERT_OK(
|
||||
master.CreateJob(dataset_id, ProcessingMode::PARALLEL_EPOCHS, &job_id));
|
||||
std::vector<TaskInfo> tasks;
|
||||
TF_ASSERT_OK(GetTasks(master_stub.get(), job_id, &tasks));
|
||||
ASSERT_EQ(tasks.size(), 2);
|
||||
bool job_finished;
|
||||
TF_EXPECT_OK(master.GetTasks(job_id, &tasks, &job_finished));
|
||||
EXPECT_EQ(tasks.size(), 2);
|
||||
EXPECT_FALSE(job_finished);
|
||||
|
||||
// Each worker produces the full dataset.
|
||||
for (TaskInfo task : tasks) {
|
||||
TF_ASSERT_OK(
|
||||
TF_EXPECT_OK(
|
||||
CheckWorkerOutput(task.worker_address(), task.id(), test_case.output));
|
||||
}
|
||||
}
|
||||
@ -188,26 +124,26 @@ TEST(DataService, AddWorkerMidEpoch) {
|
||||
TF_ASSERT_OK(cluster.Initialize());
|
||||
test_util::GraphDefTestCase test_case;
|
||||
TF_ASSERT_OK(test_util::map_test_case(&test_case));
|
||||
auto master_channel = grpc::CreateChannel(
|
||||
cluster.MasterAddress(), grpc::experimental::LocalCredentials(LOCAL_TCP));
|
||||
std::unique_ptr<MasterService::Stub> master_stub =
|
||||
MasterService::NewStub(master_channel);
|
||||
DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
|
||||
|
||||
int64 dataset_id;
|
||||
TF_ASSERT_OK(
|
||||
RegisterDataset(master_stub.get(), test_case.graph_def, &dataset_id));
|
||||
TF_ASSERT_OK(master.RegisterDataset(test_case.graph_def, &dataset_id));
|
||||
int64 job_id;
|
||||
TF_ASSERT_OK(CreateJob(master_stub.get(), dataset_id, &job_id));
|
||||
TF_ASSERT_OK(
|
||||
master.CreateJob(dataset_id, ProcessingMode::PARALLEL_EPOCHS, &job_id));
|
||||
std::vector<TaskInfo> tasks;
|
||||
TF_ASSERT_OK(GetTasks(master_stub.get(), job_id, &tasks));
|
||||
ASSERT_EQ(tasks.size(), 1);
|
||||
bool job_finished;
|
||||
TF_ASSERT_OK(master.GetTasks(job_id, &tasks, &job_finished));
|
||||
EXPECT_EQ(tasks.size(), 1);
|
||||
EXPECT_FALSE(job_finished);
|
||||
TF_ASSERT_OK(cluster.AddWorker());
|
||||
TF_ASSERT_OK(GetTasks(master_stub.get(), job_id, &tasks));
|
||||
ASSERT_EQ(tasks.size(), 2);
|
||||
TF_EXPECT_OK(master.GetTasks(job_id, &tasks, &job_finished));
|
||||
EXPECT_EQ(tasks.size(), 2);
|
||||
EXPECT_FALSE(job_finished);
|
||||
|
||||
// Each worker produces the full dataset.
|
||||
for (TaskInfo task : tasks) {
|
||||
TF_ASSERT_OK(
|
||||
TF_EXPECT_OK(
|
||||
CheckWorkerOutput(task.worker_address(), task.id(), test_case.output));
|
||||
}
|
||||
}
|
||||
|
@ -730,6 +730,12 @@ class FunctionLibraryRuntime {
|
||||
virtual void Run(const Options& opts, Handle handle,
|
||||
CallFrameInterface* call_frame, DoneCallback done) = 0;
|
||||
|
||||
virtual Status RunSync(Options opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) = 0;
|
||||
virtual Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) = 0;
|
||||
|
||||
// Creates a "kernel" for the given NodeProperties "props".
|
||||
//
|
||||
// If succeeds, returns OK and the caller takes the ownership of the
|
||||
|
@ -154,14 +154,6 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "mkl_related_tests",
|
||||
srcs = [
|
||||
"mkl_layout_pass_test.cc",
|
||||
"mkl_tfconversion_pass_test.cc",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "mobile_srcs_only_runtime",
|
||||
srcs = [
|
||||
|
@ -26,12 +26,128 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
constexpr char kFusedOpName[] = "ShuffleAndRepeatDataset";
|
||||
constexpr char kShuffleDataset[] = "ShuffleDataset";
|
||||
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
||||
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
|
||||
constexpr char kRepeatDataset[] = "RepeatDataset";
|
||||
constexpr char kShuffleAndRepeatDataset[] = "ShuffleAndRepeatDataset";
|
||||
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
|
||||
|
||||
constexpr char kOutputShapes[] = "output_shapes";
|
||||
constexpr char kOutputTypes[] = "output_types";
|
||||
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||
|
||||
Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node,
|
||||
MutableGraphView* graph, GraphDef* output,
|
||||
NodeDef* fused_node) {
|
||||
fused_node->set_op(kShuffleAndRepeatDataset);
|
||||
graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
|
||||
fused_node);
|
||||
|
||||
// Set the `input` input argument.
|
||||
fused_node->add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
fused_node->add_input(shuffle_node.input(1));
|
||||
|
||||
// Set the `seed` input argument.
|
||||
fused_node->add_input(shuffle_node.input(2));
|
||||
|
||||
// Set the `seed2` input argument.
|
||||
fused_node->add_input(shuffle_node.input(3));
|
||||
|
||||
// Set the `count` input argument.
|
||||
fused_node->add_input(repeat_node.input(1));
|
||||
|
||||
// Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
|
||||
// attributes.
|
||||
for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
graph_utils::CopyAttribute(key, shuffle_node, fused_node);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node,
|
||||
MutableGraphView* graph, GraphDef* output,
|
||||
NodeDef* fused_node) {
|
||||
fused_node->set_op(kShuffleAndRepeatDatasetV2);
|
||||
graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
|
||||
fused_node);
|
||||
|
||||
NodeDef zero_node = *graph_utils::AddScalarConstNode<int64>(0, graph);
|
||||
|
||||
// Set the `input` input argument.
|
||||
fused_node->add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
fused_node->add_input(shuffle_node.input(1));
|
||||
|
||||
// Default the `seed` input argument to 0.
|
||||
fused_node->add_input(zero_node.name());
|
||||
|
||||
// Default the `seed2` input argument to 0.
|
||||
fused_node->add_input(zero_node.name());
|
||||
|
||||
// Set the `count` input argument.
|
||||
fused_node->add_input(repeat_node.input(1));
|
||||
|
||||
// Set the `seed_generator` input argument.
|
||||
fused_node->add_input(shuffle_node.input(2));
|
||||
|
||||
// Set `output_types` and `output_shapes` attributes.
|
||||
for (auto key : {kOutputShapes, kOutputTypes}) {
|
||||
graph_utils::CopyAttribute(key, shuffle_node, fused_node);
|
||||
}
|
||||
|
||||
// Default the `reshuffle_each_iteration` attribute to true.
|
||||
(*fused_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node,
|
||||
MutableGraphView* graph, GraphDef* output,
|
||||
NodeDef* fused_node) {
|
||||
fused_node->set_op(kShuffleAndRepeatDatasetV2);
|
||||
graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
|
||||
fused_node);
|
||||
|
||||
// Set the `input` input argument.
|
||||
fused_node->add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
fused_node->add_input(shuffle_node.input(1));
|
||||
|
||||
// Set the `seed` input argument.
|
||||
fused_node->add_input(shuffle_node.input(2));
|
||||
|
||||
// Set the `seed2` input argument.
|
||||
fused_node->add_input(shuffle_node.input(3));
|
||||
|
||||
// Set the `count` input argument.
|
||||
fused_node->add_input(repeat_node.input(1));
|
||||
|
||||
// Set the `seed_generator` input argument.
|
||||
fused_node->add_input(shuffle_node.input(4));
|
||||
|
||||
// Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
|
||||
// attributes.
|
||||
for (auto key : {kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
graph_utils::CopyAttribute(key, shuffle_node, fused_node);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -42,65 +158,46 @@ Status ShuffleAndRepeatFusion::OptimizeAndCollectStats(
|
||||
MutableGraphView graph(output);
|
||||
absl::flat_hash_set<string> nodes_to_delete;
|
||||
|
||||
auto make_shuffle_and_repeat_node = [&output](const NodeDef& shuffle_node,
|
||||
const NodeDef& repeat_node) {
|
||||
NodeDef new_node;
|
||||
new_node.set_op(kFusedOpName);
|
||||
graph_utils::SetUniqueGraphNodeName(kFusedOpName, output, &new_node);
|
||||
|
||||
// Set the `input` input argument.
|
||||
new_node.add_input(shuffle_node.input(0));
|
||||
|
||||
// Set the `buffer_size` input argument.
|
||||
new_node.add_input(shuffle_node.input(1));
|
||||
|
||||
// Set the `seed` input argument.
|
||||
new_node.add_input(shuffle_node.input(2));
|
||||
|
||||
// Set the `seed2` input argument.
|
||||
new_node.add_input(shuffle_node.input(3));
|
||||
|
||||
// Set the `count` input argument.
|
||||
new_node.add_input(repeat_node.input(1));
|
||||
|
||||
// Set `output_types` and `output_shapes` attributes.
|
||||
for (auto key : {"output_shapes", "output_types"}) {
|
||||
graph_utils::CopyAttribute(key, repeat_node, &new_node);
|
||||
}
|
||||
return new_node;
|
||||
};
|
||||
|
||||
for (const NodeDef& node : item.graph.node()) {
|
||||
if (node.op() != "RepeatDataset") {
|
||||
for (const NodeDef& repeat_node : item.graph.node()) {
|
||||
if (repeat_node.op() != kRepeatDataset) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Use a more descriptive variable name now that we know the node type.
|
||||
const NodeDef& repeat_node = node;
|
||||
NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph);
|
||||
const NodeDef& shuffle_node =
|
||||
*graph_utils::GetInputNode(repeat_node, graph);
|
||||
|
||||
if (node2->op() != "ShuffleDataset") {
|
||||
NodeDef fused_node;
|
||||
if (shuffle_node.op() == kShuffleDataset) {
|
||||
TF_RETURN_IF_ERROR(FuseShuffleV1AndRepeat(shuffle_node, repeat_node,
|
||||
&graph, output, &fused_node));
|
||||
} else if (shuffle_node.op() == kShuffleDatasetV2) {
|
||||
TF_RETURN_IF_ERROR(FuseShuffleV2AndRepeat(shuffle_node, repeat_node,
|
||||
&graph, output, &fused_node));
|
||||
|
||||
} else if (shuffle_node.op() == kShuffleDatasetV3) {
|
||||
TF_RETURN_IF_ERROR(FuseShuffleV3AndRepeat(shuffle_node, repeat_node,
|
||||
&graph, output, &fused_node));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Use a more descriptive variable name now that we know the node type.
|
||||
const NodeDef& shuffle_node = *node2;
|
||||
|
||||
// TODO(b/129712758): Remove when the fused kernel supports disabling
|
||||
// reshuffling for each iteration.
|
||||
if (HasNodeAttr(shuffle_node, "reshuffle_each_iteration") &&
|
||||
!shuffle_node.attr().at("reshuffle_each_iteration").b()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
NodeDef* shuffle_and_repeat_node =
|
||||
graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node));
|
||||
NodeDef& shuffle_and_repeat_node = *graph.AddNode(std::move(fused_node));
|
||||
TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(),
|
||||
shuffle_and_repeat_node->name()));
|
||||
shuffle_and_repeat_node.name()));
|
||||
// Update shuffle node fanouts to shuffle_and_repeat fanouts to take care of
|
||||
// control dependencies.
|
||||
TF_RETURN_IF_ERROR(graph.UpdateFanouts(shuffle_node.name(),
|
||||
shuffle_and_repeat_node.name()));
|
||||
|
||||
// Mark the `Shuffle` and `Repeat` nodes for removal.
|
||||
nodes_to_delete.insert(shuffle_node.name());
|
||||
nodes_to_delete.insert(repeat_node.name());
|
||||
// Mark the `Shuffle` and `Repeat` nodes for removal (as long as neither of
|
||||
// them needs to be preserved).
|
||||
const auto nodes_to_preserve = item.NodesToPreserve();
|
||||
if (nodes_to_preserve.find(shuffle_node.name()) ==
|
||||
nodes_to_preserve.end() &&
|
||||
nodes_to_preserve.find(repeat_node.name()) == nodes_to_preserve.end()) {
|
||||
nodes_to_delete.insert(shuffle_node.name());
|
||||
nodes_to_delete.insert(repeat_node.name());
|
||||
}
|
||||
stats->num_changes++;
|
||||
}
|
||||
|
||||
|
@ -25,17 +25,21 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
||||
constexpr char kOutputShapes[] = "output_shapes";
|
||||
constexpr char kOutputTypes[] = "output_types";
|
||||
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV1AndRepeat) {
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue("output_shapes", &shapes_attr);
|
||||
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue("output_types", &types_attr);
|
||||
common_attrs[1] = std::make_pair("output_types", types_attr);
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
@ -59,6 +63,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
||||
shuffle_inputs[3] = seed2_node->name();
|
||||
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||
"", "ShuffleDataset", shuffle_inputs, common_attrs, &graph);
|
||||
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||
|
||||
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
std::vector<string> repeat_inputs(2);
|
||||
@ -85,12 +90,148 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||
for (const auto &attr :
|
||||
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||
shuffle_node->attr().at(attr)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV2AndRepeat) {
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
|
||||
std::vector<string> range_inputs(3);
|
||||
range_inputs[0] = start_node->name();
|
||||
range_inputs[1] = stop_node->name();
|
||||
range_inputs[2] = step_node->name();
|
||||
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
|
||||
common_attrs, &graph);
|
||||
|
||||
NodeDef *buffer_size_node =
|
||||
graph_utils::AddScalarConstNode<int64>(128, &graph);
|
||||
NodeDef *seed_generator_node =
|
||||
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
|
||||
std::vector<string> shuffle_inputs(3);
|
||||
shuffle_inputs[0] = range_node->name();
|
||||
shuffle_inputs[1] = buffer_size_node->name();
|
||||
shuffle_inputs[2] = seed_generator_node->name();
|
||||
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||
"", "ShuffleDatasetV2", shuffle_inputs, common_attrs, &graph);
|
||||
|
||||
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
std::vector<string> repeat_inputs(2);
|
||||
repeat_inputs[0] = shuffle_node->name();
|
||||
repeat_inputs[1] = count_node->name();
|
||||
NodeDef *repeat_node = graph_utils::AddNode(
|
||||
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
|
||||
|
||||
ShuffleAndRepeatFusion optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
|
||||
EXPECT_TRUE(
|
||||
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_shapes"),
|
||||
repeat_node->attr().at("output_shapes")));
|
||||
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
NodeDef shuffle_and_repeat_node = output.node(
|
||||
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(2));
|
||||
for (const auto &attr : {kOutputShapes, kOutputTypes}) {
|
||||
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||
shuffle_node->attr().at(attr)));
|
||||
}
|
||||
EXPECT_TRUE(shuffle_and_repeat_node.attr().at(kReshuffleEachIteration).b());
|
||||
}
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, FuseShuffleV3AndRepeat) {
|
||||
GrapplerItem item;
|
||||
MutableGraphView graph(&item.graph);
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
NodeDef *step_node = graph_utils::AddScalarConstNode<int64>(1, &graph);
|
||||
|
||||
std::vector<string> range_inputs(3);
|
||||
range_inputs[0] = start_node->name();
|
||||
range_inputs[1] = stop_node->name();
|
||||
range_inputs[2] = step_node->name();
|
||||
NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs,
|
||||
common_attrs, &graph);
|
||||
|
||||
NodeDef *buffer_size_node =
|
||||
graph_utils::AddScalarConstNode<int64>(128, &graph);
|
||||
NodeDef *seed_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
NodeDef *seed2_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
NodeDef *seed_generator_node =
|
||||
graph_utils::AddScalarConstNode<StringPiece>("dummy_resource", &graph);
|
||||
std::vector<string> shuffle_inputs(5);
|
||||
shuffle_inputs[0] = range_node->name();
|
||||
shuffle_inputs[1] = buffer_size_node->name();
|
||||
shuffle_inputs[2] = seed_node->name();
|
||||
shuffle_inputs[3] = seed2_node->name();
|
||||
shuffle_inputs[4] = seed_generator_node->name();
|
||||
NodeDef *shuffle_node = graph_utils::AddNode(
|
||||
"", "ShuffleDatasetV3", shuffle_inputs, common_attrs, &graph);
|
||||
(*shuffle_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
|
||||
|
||||
NodeDef *count_node = graph_utils::AddScalarConstNode<int64>(-1, &graph);
|
||||
std::vector<string> repeat_inputs(2);
|
||||
repeat_inputs[0] = shuffle_node->name();
|
||||
repeat_inputs[1] = count_node->name();
|
||||
NodeDef *repeat_node = graph_utils::AddNode(
|
||||
"", "RepeatDataset", repeat_inputs, common_attrs, &graph);
|
||||
|
||||
ShuffleAndRepeatFusion optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(shuffle_node->name(), output));
|
||||
EXPECT_FALSE(
|
||||
graph_utils::ContainsGraphNodeWithName(repeat_node->name(), output));
|
||||
EXPECT_TRUE(
|
||||
AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_types"),
|
||||
repeat_node->attr().at("output_types")));
|
||||
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
NodeDef shuffle_and_repeat_node = output.node(
|
||||
graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDatasetV2", output));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 6);
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
|
||||
EXPECT_EQ(shuffle_and_repeat_node.input(5), shuffle_node->input(4));
|
||||
for (const auto &attr :
|
||||
{kOutputShapes, kOutputTypes, kReshuffleEachIteration}) {
|
||||
EXPECT_TRUE(AreAttrValuesEqual(shuffle_and_repeat_node.attr().at(attr),
|
||||
shuffle_node->attr().at(attr)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
||||
@ -99,11 +240,11 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
||||
|
||||
std::vector<std::pair<string, AttrValue>> common_attrs(2);
|
||||
AttrValue shapes_attr;
|
||||
SetAttrValue("output_shapes", &shapes_attr);
|
||||
common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
|
||||
SetAttrValue(kOutputShapes, &shapes_attr);
|
||||
common_attrs[0] = std::make_pair(kOutputShapes, shapes_attr);
|
||||
AttrValue types_attr;
|
||||
SetAttrValue("output_types", &types_attr);
|
||||
common_attrs[1] = std::make_pair("output_types", types_attr);
|
||||
SetAttrValue(kOutputTypes, &types_attr);
|
||||
common_attrs[1] = std::make_pair(kOutputTypes, types_attr);
|
||||
|
||||
NodeDef *start_node = graph_utils::AddScalarConstNode<int64>(0, &graph);
|
||||
NodeDef *stop_node = graph_utils::AddScalarConstNode<int64>(10, &graph);
|
||||
|
@ -4540,6 +4540,9 @@ tf_cuda_cc_test(
|
||||
name = "split_v_op_test",
|
||||
size = "small",
|
||||
srcs = ["split_v_op_test.cc"],
|
||||
tags = [
|
||||
"no_windows", # split_v_op uses lrand48 which does not exist on Windows
|
||||
],
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
|
@ -671,20 +671,13 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
|
||||
|
||||
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
|
||||
ret_types_);
|
||||
Notification n;
|
||||
Status s;
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
return absl::StrCat(
|
||||
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
|
||||
},
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
|
||||
s.Update(func_status);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
return frame.ConsumeRetvals(rets);
|
||||
}
|
||||
|
||||
@ -709,9 +702,6 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||
|
||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||
ret_types_);
|
||||
Notification n;
|
||||
Status s;
|
||||
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
return absl::StrCat(
|
||||
@ -719,12 +709,7 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||
f_opts.step_id, "#");
|
||||
},
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
|
||||
s.Update(func_status);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
return frame.ConsumeRetvals(rets);
|
||||
}
|
||||
|
||||
@ -748,21 +733,13 @@ Status InstantiatedCapturedFunction::RunInstantiated(
|
||||
|
||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||
ret_types_);
|
||||
Notification n;
|
||||
Status s;
|
||||
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
return absl::StrCat("InstantiatedCapturedFunction::RunInstantiated#id=",
|
||||
f_opts.step_id, "#");
|
||||
},
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
lib_->Run(f_opts, f_handle_, &frame, [&n, &s](const Status& func_status) {
|
||||
s.Update(func_status);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(s);
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
return frame.ConsumeRetvals(rets);
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Description:
|
||||
# Contains experimental kernels for datasets and iterators.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
@ -132,12 +131,9 @@ tf_kernel_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/data/service:common_proto_cc",
|
||||
"//tensorflow/core/data/service:compression_utils",
|
||||
"//tensorflow/core/data/service:credentials_factory",
|
||||
"//tensorflow/core/data/service:grpc_util",
|
||||
"//tensorflow/core/data/service:master_cc_grpc_proto",
|
||||
"//tensorflow/core/data/service:master_proto_cc",
|
||||
"//tensorflow/core/data/service:worker_cc_grpc_proto",
|
||||
"//tensorflow/core/data/service:data_service",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
@ -145,7 +141,6 @@ tf_kernel_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -158,13 +153,9 @@ tf_kernel_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/data/service:credentials_factory",
|
||||
"//tensorflow/core/data/service:grpc_util",
|
||||
"//tensorflow/core/data/service:master_cc_grpc_proto",
|
||||
"//tensorflow/core/data/service:master_proto_cc",
|
||||
"//tensorflow/core/data/service:data_service",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,18 +18,12 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/impl/codegen/server_context.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/data/service/common.pb.h"
|
||||
#include "tensorflow/core/data/service/compression_utils.h"
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/master.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/data_service.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/model.h"
|
||||
@ -155,8 +149,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
"over the dataset via `create_iterator(dataset, job_token).`");
|
||||
}
|
||||
job_id_ = ctx->job_token().job_id();
|
||||
TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
|
||||
dataset()->protocol_, &credentials_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -212,7 +204,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
int64 task_id;
|
||||
// Cached address of the worker for task `task_id`.
|
||||
std::string address;
|
||||
std::unique_ptr<WorkerService::Stub> worker_stub;
|
||||
std::unique_ptr<DataServiceWorkerClient> worker;
|
||||
std::unique_ptr<Thread> thread;
|
||||
bool end_of_sequence = false;
|
||||
// Indicates that the thread has finished running.
|
||||
@ -225,9 +217,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
// the list of tasks changes.
|
||||
void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
|
||||
VLOG(3) << "Starting task handler manager";
|
||||
auto channel = ::grpc::CreateChannel(dataset()->address_, credentials_);
|
||||
std::unique_ptr<MasterService::Stub> master_stub =
|
||||
MasterService::NewStub(channel);
|
||||
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
|
||||
|
||||
uint64 next_check = Env::Default()->NowMicros();
|
||||
while (true) {
|
||||
@ -244,29 +234,27 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
return;
|
||||
}
|
||||
}
|
||||
UpdateTaskThreads(master_stub.get(), ctx.get());
|
||||
UpdateTaskThreads(&master, ctx.get());
|
||||
next_check = Env::Default()->NowMicros() +
|
||||
dataset()->task_refresh_interval_ms_ * 1000;
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateTaskThreads(MasterService::Stub* master_stub,
|
||||
void UpdateTaskThreads(DataServiceMasterClient* master,
|
||||
IteratorContext* ctx) LOCKS_EXCLUDED(mu_) {
|
||||
VLOG(3) << "Updating task handler threads";
|
||||
GetTasksResponse resp;
|
||||
GetTasksRequest req;
|
||||
req.set_job_id(job_id_);
|
||||
grpc::ClientContext client_ctx;
|
||||
grpc::Status s = master_stub->GetTasks(&client_ctx, req, &resp);
|
||||
std::vector<TaskInfo> tasks;
|
||||
bool job_finished;
|
||||
Status s = master->GetTasks(job_id_, &tasks, &job_finished);
|
||||
if (!s.ok()) {
|
||||
LOG(INFO) << "Failed to get task info for job id " << job_id_ << ": "
|
||||
<< s.error_message() << "(" << s.error_code() << ")";
|
||||
LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
|
||||
<< s;
|
||||
return;
|
||||
}
|
||||
absl::flat_hash_set<int64> task_ids;
|
||||
mutex_lock l(mu_);
|
||||
job_finished_ = resp.job_finished();
|
||||
for (auto& task : resp.task_info()) {
|
||||
job_finished_ = job_finished;
|
||||
for (auto& task : tasks) {
|
||||
task_ids.insert(task.id());
|
||||
if (task_threads_.contains(task.id())) {
|
||||
continue;
|
||||
@ -315,11 +303,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
<< task_handler->task_id << " with worker address "
|
||||
<< task_handler->address;
|
||||
while (true) {
|
||||
if (!task_handler->worker_stub) {
|
||||
Status s = CreateWorkerStub(task_handler->address,
|
||||
&task_handler->worker_stub);
|
||||
if (!task_handler->worker) {
|
||||
Status s = CreateDataServiceWorkerClient(task_handler->address,
|
||||
dataset()->protocol_,
|
||||
&task_handler->worker);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to create a worker stub for "
|
||||
LOG(WARNING) << "Failed to create a worker client for "
|
||||
<< task_handler->address << ": " << s;
|
||||
}
|
||||
}
|
||||
@ -359,9 +348,11 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
// `results_`.
|
||||
Status FetchElement(TaskThread* task_handler, int64 deadline_micros) {
|
||||
VLOG(3) << "Fetching an element for task id " << task_handler->task_id;
|
||||
GetElementResponse resp;
|
||||
CompressedElement compressed;
|
||||
bool end_of_sequence;
|
||||
for (int num_retries = 0;; ++num_retries) {
|
||||
Status s = RequestElement(task_handler, &resp);
|
||||
Status s = task_handler->worker->GetElement(
|
||||
task_handler->task_id, &compressed, &end_of_sequence);
|
||||
if (s.ok()) {
|
||||
break;
|
||||
}
|
||||
@ -395,12 +386,11 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
std::vector<Tensor> element;
|
||||
if (!resp.end_of_sequence()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
service_util::Uncompress(resp.compressed_element(), &element));
|
||||
if (!end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element));
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
if (resp.end_of_sequence()) {
|
||||
if (end_of_sequence) {
|
||||
task_handler->end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -410,31 +400,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RequestElement(TaskThread* task_handler, GetElementResponse* resp) {
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_handler->task_id);
|
||||
grpc::ClientContext client_ctx;
|
||||
grpc::Status s =
|
||||
task_handler->worker_stub->GetElement(&client_ctx, req, resp);
|
||||
if (s.ok()) {
|
||||
return Status::OK();
|
||||
}
|
||||
return grpc_util::WrapError("Failed to request an element", s);
|
||||
}
|
||||
|
||||
Status CreateWorkerStub(const std::string& worker_address,
|
||||
std::unique_ptr<WorkerService::Stub>* stub) {
|
||||
::grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(-1);
|
||||
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
||||
TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
|
||||
dataset()->protocol_, &credentials));
|
||||
auto channel =
|
||||
::grpc::CreateCustomChannel(worker_address, credentials, args);
|
||||
*stub = WorkerService::NewStub(channel);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
// TODO(aaudibert): split this into a couple cvs for different conditions
|
||||
// so that we can use notify_one and avoid unnecessary wakeups.
|
||||
@ -450,7 +415,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
// Set once in Initialize().
|
||||
int64 job_id_;
|
||||
std::shared_ptr<::grpc::ChannelCredentials> credentials_;
|
||||
int64 num_unfinished_tasks_ TF_GUARDED_BY(mu_) = 0;
|
||||
|
||||
bool job_finished_ = false;
|
||||
|
@ -15,11 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/data/experimental/data_service_ops.h"
|
||||
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/data_service.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
@ -69,26 +65,14 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
|
||||
|
||||
VLOG(3) << "Registering dataset with master at " << address
|
||||
<< ". Protocol=" << protocol;
|
||||
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, CredentialsFactory::CreateClientCredentials(protocol, &credentials));
|
||||
auto channel = ::grpc::CreateChannel(address, credentials);
|
||||
auto master_stub = MasterService::NewStub(channel);
|
||||
GetOrRegisterDatasetRequest req;
|
||||
*req.mutable_dataset()->mutable_graph() = graph_def;
|
||||
GetOrRegisterDatasetResponse resp;
|
||||
grpc::ClientContext client_ctx;
|
||||
auto status = master_stub->GetOrRegisterDataset(&client_ctx, req, &resp);
|
||||
if (!status.ok()) {
|
||||
ctx->CtxFailure(grpc_util::WrapError("Failed to register dataset", status));
|
||||
return;
|
||||
}
|
||||
DataServiceMasterClient client(address, protocol);
|
||||
int64 dataset_id;
|
||||
OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));
|
||||
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output));
|
||||
auto output_dataset_id = output->tensor<int64, 0>();
|
||||
output_dataset_id() = resp.dataset_id();
|
||||
output_dataset_id() = dataset_id;
|
||||
}
|
||||
|
||||
CreateJobOp::CreateJobOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
@ -114,24 +98,11 @@ void CreateJobOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseProcessingMode(processing_mode_str, &processing_mode));
|
||||
|
||||
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, CredentialsFactory::CreateClientCredentials(protocol, &credentials));
|
||||
auto channel = ::grpc::CreateChannel(address, credentials);
|
||||
auto master_stub = MasterService::NewStub(channel);
|
||||
CreateJobRequest req;
|
||||
req.set_dataset_id(dataset_id);
|
||||
req.set_processing_mode(ProcessingModeDef(processing_mode));
|
||||
CreateJobResponse resp;
|
||||
grpc::ClientContext client_ctx;
|
||||
auto status = master_stub->CreateJob(&client_ctx, req, &resp);
|
||||
if (!status.ok()) {
|
||||
ctx->CtxFailure(grpc_util::WrapError(
|
||||
absl::StrCat("Failed to begin epoch for dataset id ", dataset_id),
|
||||
status));
|
||||
return;
|
||||
}
|
||||
JobToken token(resp.job_id());
|
||||
DataServiceMasterClient client(address, protocol);
|
||||
int64 job_id;
|
||||
OP_REQUIRES_OK(ctx, client.CreateJob(dataset_id, processing_mode, &job_id));
|
||||
|
||||
JobToken token(job_id);
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output));
|
||||
auto output_token = output->tensor<Variant, 0>();
|
||||
|
@ -44,13 +44,6 @@ class RegisterDatasetOp : public OpKernel {
|
||||
SerializationContext::ExternalStatePolicy external_state_policy_;
|
||||
};
|
||||
|
||||
enum class ProcessingMode : int64 {
|
||||
// Each tf.data worker processes an entire epoch.
|
||||
PARALLEL_EPOCHS = 0,
|
||||
// Processing of an epoch is distributed across all tf.data workers.
|
||||
ONE_EPOCH = 1,
|
||||
};
|
||||
|
||||
// Creates a token for reading from the tf.data service.
|
||||
//
|
||||
// The dataset_id input identifies which dataset to create a token for.
|
||||
|
@ -841,16 +841,9 @@ class OneShotIteratorOp : public AsyncOpKernel {
|
||||
opts.step_container = &step_container;
|
||||
opts.runner = ctx->runner();
|
||||
opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
|
||||
Notification n;
|
||||
Status factory_status;
|
||||
std::vector<Tensor> return_values;
|
||||
ctx->function_library()->Run(opts, f_handle, {}, &return_values,
|
||||
[&n, &factory_status](Status s) {
|
||||
factory_status.Update(s);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
TF_RETURN_IF_ERROR(factory_status);
|
||||
TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
|
||||
std::move(opts), f_handle, {}, &return_values));
|
||||
if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
|
||||
!TensorShapeUtils::IsScalar(return_values[0].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
|
@ -44,10 +44,10 @@ namespace data {
|
||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
|
||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
|
||||
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
|
||||
/* static */ constexpr const char* const
|
||||
ShuffleDatasetOpBase::kReshuffleEachIteration;
|
||||
|
||||
/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
|
||||
/* static */ constexpr const char* const
|
||||
ShuffleDatasetOp::kReshuffleEachIteration;
|
||||
|
||||
/* static */ constexpr const char* const
|
||||
ShuffleAndRepeatDatasetOp::kDatasetType;
|
||||
@ -72,6 +72,8 @@ constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
|
||||
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
|
||||
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
|
||||
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
|
||||
constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDatasetV1";
|
||||
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
|
||||
|
||||
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {}
|
||||
@ -225,6 +227,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
while (!slices_.empty() &&
|
||||
slices_.front()->start == slices_.front()->end) {
|
||||
slices_.pop_front();
|
||||
// Reinitialize the RNG state for the next epoch.
|
||||
num_random_samples_ = 0;
|
||||
seed_generator_->GenerateSeeds(&seed_, &seed2_);
|
||||
ResetRngs();
|
||||
}
|
||||
DCHECK(!slices_.empty());
|
||||
// Choose an element to produce uniformly at random from the first
|
||||
@ -663,6 +669,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
owns_resource = true;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
@ -679,7 +686,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
||||
owns_resource = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
@ -695,6 +701,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
handle.container(), handle.name(), &manager);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
owns_resource = true;
|
||||
LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
|
||||
"using a non-deterministically seeded generator and "
|
||||
"reshuffling each iteration.";
|
||||
@ -708,7 +715,6 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
owns_resource = true;
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
@ -790,9 +796,13 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
|
||||
AttrValue reshuffle_each_iteration;
|
||||
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
|
||||
&reshuffle_each_iteration);
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(
|
||||
this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
|
||||
{}, // Attrs
|
||||
{std::make_pair(kReshuffleEachIteration,
|
||||
reshuffle_each_iteration)}, // Attrs
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -804,8 +814,83 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
const RandomSeeds seeds_;
|
||||
};
|
||||
|
||||
class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
|
||||
public:
|
||||
DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
|
||||
int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
|
||||
ResourceHandle&& resource_handle, bool owns_resource)
|
||||
: ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
|
||||
manager_(manager),
|
||||
owns_resource_(owns_resource),
|
||||
resource_handle_(std::move(resource_handle)),
|
||||
resource_mgr_(ctx->resource_manager()),
|
||||
seeds_(std::move(seeds)) {}
|
||||
|
||||
~DatasetV2() override {
|
||||
manager_->Unref();
|
||||
if (owns_resource_) {
|
||||
Status s = resource_mgr_->Delete<SeedGeneratorManager>(
|
||||
resource_handle_.container(), resource_handle_.name());
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
string op_type() const override { return kDatasetType; }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
Node* buffer_size_node = nullptr;
|
||||
Node* seed_node = nullptr;
|
||||
Node* seed2_node = nullptr;
|
||||
Node* count_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
|
||||
Node* resource_handle_node = nullptr;
|
||||
Tensor handle(DT_RESOURCE, TensorShape({}));
|
||||
handle.scalar<ResourceHandle>()() = resource_handle_;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
|
||||
AttrValue reshuffle_each_iteration;
|
||||
b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
|
||||
&reshuffle_each_iteration);
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this,
|
||||
{input_graph_node, buffer_size_node, seed_node,
|
||||
seed2_node, count_node, resource_handle_node}, // Inputs
|
||||
{std::make_pair(kReshuffleEachIteration,
|
||||
reshuffle_each_iteration)}, // Attrs
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
SeedGeneratorManager* const manager_; // Owned
|
||||
const bool owns_resource_;
|
||||
const ResourceHandle resource_handle_;
|
||||
ResourceMgr* const resource_mgr_; // Not owned.
|
||||
const RandomSeeds seeds_;
|
||||
};
|
||||
|
||||
ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
|
||||
: ShuffleDatasetOpBase(ctx) {}
|
||||
: ShuffleDatasetOpBase(ctx) {
|
||||
auto& op_name = ctx->def().op();
|
||||
if (op_name == kShuffleAndRepeatDatasetV2) {
|
||||
op_version_ = 2;
|
||||
} else if (op_name == kShuffleAndRepeatDatasetV1) {
|
||||
op_version_ = 1;
|
||||
}
|
||||
if (ctx->HasAttr(kReshuffleEachIteration)) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
|
||||
}
|
||||
}
|
||||
|
||||
void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
DatasetBase* input,
|
||||
@ -826,29 +911,76 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
int64 count;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
|
||||
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
|
||||
OP_REQUIRES(ctx, count > 0 || count == -1,
|
||||
errors::InvalidArgument(
|
||||
"count must be greater than zero or equal to -1."));
|
||||
|
||||
RandomSeeds seeds(seed, seed2);
|
||||
|
||||
static std::atomic<int64> resource_id_counter(0);
|
||||
const string& container = ctx->resource_manager()->default_container();
|
||||
auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
|
||||
resource_id_counter.fetch_add(1));
|
||||
SeedGeneratorManager* manager;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager, [&seeds](SeedGeneratorManager** manager) {
|
||||
*manager = new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
if (op_version_ == 2) {
|
||||
auto handle = HandleFromInput(ctx, 5);
|
||||
SeedGeneratorManager* manager = nullptr;
|
||||
Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
|
||||
handle.container(), handle.name(), &manager);
|
||||
bool owns_resource = false;
|
||||
if (errors::IsNotFound(s)) {
|
||||
owns_resource = true;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager,
|
||||
[reshuffle = reshuffle_each_iteration_,
|
||||
&seeds](SeedGeneratorManager** manager) {
|
||||
if (reshuffle) {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
} else {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
|
||||
} else {
|
||||
OP_REQUIRES_OK(ctx, s);
|
||||
}
|
||||
|
||||
// Ownership of manager is transferred onto `Dataset`.
|
||||
*output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
|
||||
count, std::move(handle));
|
||||
// Ownership of manager is transferred onto `DatasetV2`.
|
||||
*output = new ShuffleAndRepeatDatasetOp::DatasetV2(
|
||||
ctx, input, buffer_size, count, std::move(seeds), manager,
|
||||
std::move(handle), owns_resource);
|
||||
} else {
|
||||
if (op_version_ != 1) {
|
||||
LOG(WARNING) << "Unsupported version of shuffle dataset op: "
|
||||
<< op_version_ << ". Defaulting to version 1.";
|
||||
}
|
||||
SeedGeneratorManager* manager;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
|
||||
container, name, &manager,
|
||||
[reshuffle = reshuffle_each_iteration_,
|
||||
&seeds](SeedGeneratorManager** manager) {
|
||||
if (reshuffle) {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new RandomSeedGenerator(seeds));
|
||||
} else {
|
||||
*manager =
|
||||
new SeedGeneratorManager(new FixedSeedGenerator(seeds));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
auto handle =
|
||||
MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
|
||||
|
||||
// Ownership of manager is transferred onto `Dataset`.
|
||||
*output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
|
||||
count, std::move(handle));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -863,6 +995,9 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
|
||||
ShuffleAndRepeatDatasetOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
|
||||
ShuffleAndRepeatDatasetOp);
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
|
||||
static constexpr const char* const kSeed2 = "seed2";
|
||||
static constexpr const char* const kOutputTypes = "output_types";
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
static constexpr const char* const kReshuffleEachIteration =
|
||||
"reshuffle_each_iteration";
|
||||
|
||||
explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx);
|
||||
|
||||
@ -38,8 +40,6 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
|
||||
class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
||||
public:
|
||||
static constexpr const char* const kDatasetType = "Shuffle";
|
||||
static constexpr const char* const kReshuffleEachIteration =
|
||||
"reshuffle_each_iteration";
|
||||
|
||||
explicit ShuffleDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
@ -52,7 +52,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
|
||||
class DatasetV2;
|
||||
class DatasetV3;
|
||||
int op_version_ = 0;
|
||||
bool reshuffle_each_iteration_;
|
||||
bool reshuffle_each_iteration_ = true;
|
||||
};
|
||||
|
||||
class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
||||
@ -68,6 +68,9 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
|
||||
|
||||
private:
|
||||
class Dataset;
|
||||
class DatasetV2;
|
||||
int op_version_ = 0;
|
||||
bool reshuffle_each_iteration_ = true;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -72,10 +72,8 @@ class ShuffleDatasetParams : public DatasetParams {
|
||||
output_dtypes_);
|
||||
attr_vector->emplace_back(ShuffleDatasetOpBase::kOutputShapes,
|
||||
output_shapes_);
|
||||
if (count_ == 1) {
|
||||
attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration,
|
||||
reshuffle_each_iteration_);
|
||||
}
|
||||
attr_vector->emplace_back(ShuffleDatasetOp::kReshuffleEachIteration,
|
||||
reshuffle_each_iteration_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -297,23 +295,23 @@ std::vector<GetNextTestCase<ShuffleDatasetParams>> GetNextTestCases() {
|
||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}),
|
||||
{{2}, {6}, {1}, {3}, {9}, {5}, {0}, {8}, {7}, {4},
|
||||
{0}, {5}, {1}, {7}, {2}, {9}, {8}, {4}, {6}, {3}}),
|
||||
{{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5},
|
||||
{9}, {0}, {8}, {6}, {1}, {3}, {7}, {2}, {4}, {5}}),
|
||||
/*expected_reshuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}), {{1}, {6}, {0}, {5}, {2}, {7}, {4},
|
||||
{3}, {9}, {8}, {6}, {5}, {0}, {9},
|
||||
{4}, {7}, {2}, {8}, {1}, {3}})},
|
||||
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||
{2}, {4}, {5}, {9}, {0}, {8}, {6},
|
||||
{1}, {3}, {7}, {2}, {4}, {5}})},
|
||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}}),
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}}),
|
||||
/*expected_reshuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{1}, {0}, {2}, {0}, {1}, {2}, {2}, {1}, {0}, {0}, {1},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {0}, {2}})}};
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
|
||||
}
|
||||
|
||||
class ParameterizedGetNextTest : public ShuffleDatasetOpTest,
|
||||
@ -496,16 +494,16 @@ IteratorSaveAndRestoreTestCases() {
|
||||
{/*dataset_params=*/ShuffleDatasetParams7(),
|
||||
/*breakpoints=*/{0, 5, 22},
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(TensorShape({}), {{2}, {6}, {1}, {3}, {9}, {5}, {0},
|
||||
{8}, {7}, {4}, {0}, {5}, {1}, {7},
|
||||
{2}, {9}, {8}, {4}, {6}, {3}})},
|
||||
CreateTensors<int64>(TensorShape({}), {{9}, {0}, {8}, {6}, {1}, {3}, {7},
|
||||
{2}, {4}, {5}, {9}, {0}, {8}, {6},
|
||||
{1}, {3}, {7}, {2}, {4}, {5}})},
|
||||
{/*dataset_params=*/ShuffleDatasetParams8(),
|
||||
/*breakpoints=*/{0, 5, 20},
|
||||
/*expected_shuffle_outputs=*/
|
||||
CreateTensors<int64>(
|
||||
TensorShape({}),
|
||||
{{1}, {2}, {0}, {1}, {2}, {0}, {1}, {0}, {2}, {1}, {0},
|
||||
{2}, {0}, {2}, {1}, {0}, {1}, {2}, {1}, {2}, {0}})}};
|
||||
{{2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0},
|
||||
{1}, {2}, {0}, {1}, {2}, {0}, {1}, {2}, {0}, {1}})}};
|
||||
}
|
||||
|
||||
class ParameterizedIteratorSaveAndRestoreTest
|
||||
|
@ -37,3 +37,49 @@ op {
|
||||
minimum: 1
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ShuffleAndRepeatDataset"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "buffer_size"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed2"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "count"
|
||||
type: DT_INT64
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "reshuffle_each_iteration"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,51 @@
|
||||
op {
|
||||
name: "ShuffleAndRepeatDatasetV2"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "buffer_size"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed2"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "count"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed_generator"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "reshuffle_each_iteration"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
@ -507,6 +507,7 @@ REGISTER_OP("ShuffleAndRepeatDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.Attr("reshuffle_each_iteration: bool = true")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, seed2, and count should be scalars.
|
||||
@ -517,6 +518,28 @@ REGISTER_OP("ShuffleAndRepeatDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ShuffleAndRepeatDatasetV2")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("buffer_size: int64")
|
||||
.Input("seed: int64")
|
||||
.Input("seed2: int64")
|
||||
.Input("count: int64")
|
||||
.Input("seed_generator: resource")
|
||||
.Output("handle: variant")
|
||||
.Attr("reshuffle_each_iteration: bool = true")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, seed2, count, and seed_generator should be scalars.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("AnonymousMemoryCache")
|
||||
.Output("handle: resource")
|
||||
.Output("deleter: variant")
|
||||
|
@ -42275,6 +42275,64 @@ op {
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "reshuffle_each_iteration"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "ShuffleAndRepeatDatasetV2"
|
||||
input_arg {
|
||||
name: "input_dataset"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
input_arg {
|
||||
name: "buffer_size"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed2"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "count"
|
||||
type: DT_INT64
|
||||
}
|
||||
input_arg {
|
||||
name: "seed_generator"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
output_arg {
|
||||
name: "handle"
|
||||
type: DT_VARIANT
|
||||
}
|
||||
attr {
|
||||
name: "reshuffle_each_iteration"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "output_types"
|
||||
type: "list(type)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
attr {
|
||||
name: "output_shapes"
|
||||
type: "list(shape)"
|
||||
has_minimum: true
|
||||
minimum: 1
|
||||
}
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "ShuffleDataset"
|
||||
|
@ -108,7 +108,7 @@ limitations under the License.
|
||||
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
||||
#define TF_GRAPH_DEF_VERSION 386 // Updated: 2020/4/29
|
||||
#define TF_GRAPH_DEF_VERSION 387 // Updated: 2020/4/30
|
||||
|
||||
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
||||
//
|
||||
|
@ -22,6 +22,24 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
bool CanUseCudnn() {
|
||||
static bool is_enabled = [] {
|
||||
bool is_enabled = true;
|
||||
// TODO(b/155239286): Remove TF_USE_CUDNN after TF 2.3 is released.
|
||||
Status status =
|
||||
ReadBoolFromEnvVar("TF_USE_CUDNN", /*default_val=*/true, &is_enabled);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
if (!is_enabled) {
|
||||
LOG(WARNING) << "The environmental variable TF_USE_CUDNN is deprecated "
|
||||
"and will be ignored in the future";
|
||||
}
|
||||
return is_enabled;
|
||||
}();
|
||||
return is_enabled;
|
||||
}
|
||||
|
||||
#define ADD_BOOL_CUDNN_FLAG(func_name, flag_name, default_value) \
|
||||
bool func_name() { \
|
||||
bool value = default_value; \
|
||||
@ -32,7 +50,6 @@ namespace tensorflow {
|
||||
return value; \
|
||||
}
|
||||
|
||||
ADD_BOOL_CUDNN_FLAG(CanUseCudnn, TF_USE_CUDNN, true);
|
||||
ADD_BOOL_CUDNN_FLAG(CudnnUseAutotune, TF_CUDNN_USE_AUTOTUNE, true);
|
||||
// Whether to auto-tuning Cudnn RNN forward and backward pass to pick
|
||||
// statistically the best cudnnRNNAlgo_t and cudnnMathType_t.
|
||||
|
@ -16989,6 +16989,17 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x
|
||||
return op.Output(0), op.Output(1), op.Output(2)
|
||||
}
|
||||
|
||||
// ShuffleAndRepeatDatasetAttr is an optional argument to ShuffleAndRepeatDataset.
|
||||
type ShuffleAndRepeatDatasetAttr func(optionalAttr)
|
||||
|
||||
// ShuffleAndRepeatDatasetReshuffleEachIteration sets the optional reshuffle_each_iteration attribute to value.
|
||||
// If not specified, defaults to true
|
||||
func ShuffleAndRepeatDatasetReshuffleEachIteration(value bool) ShuffleAndRepeatDatasetAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["reshuffle_each_iteration"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Creates a dataset that shuffles and repeats elements from `input_dataset`
|
||||
//
|
||||
// pseudorandomly.
|
||||
@ -17006,11 +17017,14 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x
|
||||
// should be repeated. The default is `-1`, which results in infinite repetition.
|
||||
//
|
||||
//
|
||||
func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
|
||||
func ShuffleAndRepeatDataset(scope *Scope, input_dataset tf.Output, buffer_size tf.Output, seed tf.Output, seed2 tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape, optional ...ShuffleAndRepeatDatasetAttr) (handle tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "ShuffleAndRepeatDataset",
|
||||
Input: []tf.Input{
|
||||
@ -26814,7 +26828,7 @@ func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output)
|
||||
//
|
||||
// @tf.function
|
||||
// def foo(x, y):
|
||||
// return = mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
|
||||
// return mlir_passthrough_op([x, y], mlir_module, Toutputs=[tf.float32])
|
||||
//
|
||||
// graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.TensorSpec([10], tf.float32)).graph.as_graph_def()
|
||||
// ```
|
||||
|
@ -913,6 +913,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_SEGMENT_SUM:
|
||||
return kTfLiteOk;
|
||||
}
|
||||
return kTfLiteError;
|
||||
} // NOLINT[readability/fn_size]
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -1,52 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn_init.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "remote.h" // NOLINT
|
||||
#include "rpcmem.h" // NOLINT
|
||||
#include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/soc_model.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
// Version 1.17
|
||||
static const int kHexagonNNVersion = 136960;
|
||||
#pragma weak remote_handle_control // Declare it as a weak symbol
|
||||
void hexagon_nn_global_init() {
|
||||
rpcmem_init();
|
||||
// Non-domains QoS invocation
|
||||
struct remote_rpc_control_latency data;
|
||||
data.enable = 1;
|
||||
if (remote_handle_control) { // Check if API is available before invoking
|
||||
remote_handle_control(DSPRPC_CONTROL_LATENCY, (void*)&data, sizeof(data));
|
||||
}
|
||||
}
|
||||
|
||||
void hexagon_nn_global_teardown() { rpcmem_deinit(); }
|
||||
|
||||
bool hexagon_nn_is_device_supported() {
|
||||
return tflite::delegates::getsoc_model().mode != UNSPECIFIED_MODE;
|
||||
}
|
||||
|
||||
int hexagon_nn_hexagon_interface_version() { return kHexagonNNVersion; }
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/lite/context.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/external_cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
@ -749,10 +748,22 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
|
||||
ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk);
|
||||
|
||||
TfLiteQuantizationParams quantized;
|
||||
char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'};
|
||||
|
||||
// String tensor with one string of length 3
|
||||
union {
|
||||
char raw_bytes[15];
|
||||
struct {
|
||||
int32_t num_strs;
|
||||
int32_t offsets[2];
|
||||
char str_data[3];
|
||||
} tensor_data;
|
||||
} data;
|
||||
data.tensor_data = {1, {12, 15}, {'A', 'B', 'C'}};
|
||||
|
||||
// Read only string tensor.
|
||||
ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1},
|
||||
quantized, data, 15),
|
||||
quantized, data.raw_bytes,
|
||||
sizeof(data.raw_bytes)),
|
||||
kTfLiteOk);
|
||||
// Read-write string tensor.
|
||||
ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1},
|
||||
|
@ -348,6 +348,14 @@ filegroup(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "portable_gpu_tests",
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "libtensorflowlite_jni",
|
||||
srcs = select({
|
||||
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.gpu;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.lite.Interpreter;
|
||||
import org.tensorflow.lite.TestUtils;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.lite.gpu.GpuDelegate}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class GpuDelegateTest {
|
||||
|
||||
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin";
|
||||
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
|
||||
|
||||
@Test
|
||||
public void testBasic() throws Exception {
|
||||
try (GpuDelegate delegate = new GpuDelegate()) {
|
||||
assertThat(delegate.getNativeHandle()).isNotEqualTo(0);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInterpreterWithGpu() throws Exception {
|
||||
Interpreter.Options options = new Interpreter.Options();
|
||||
try (GpuDelegate delegate = new GpuDelegate();
|
||||
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
|
||||
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
float[][][][] parsedOutputs = new float[2][8][8][3];
|
||||
interpreter.run(fourD, parsedOutputs);
|
||||
float[] outputOneD = parsedOutputs[0][0][0];
|
||||
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
}
|
||||
}
|
@ -338,6 +338,7 @@ cc_library(
|
||||
# Depend on ruy regardless of `tflite_with_ruy`. See the comment in
|
||||
# cpu_backend_gemm.h about why ruy is the generic path.
|
||||
"@ruy//ruy",
|
||||
"@ruy//ruy:matrix",
|
||||
"@ruy//ruy:path",
|
||||
"@ruy//ruy/profiler:instrumentation",
|
||||
# We only need to depend on gemmlowp and Eigen when tflite_with_ruy
|
||||
@ -525,6 +526,7 @@ cc_library(
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":cpu_backend_context",
|
||||
":cpu_backend_gemm",
|
||||
":cpu_backend_threadpool",
|
||||
":eigen_support",
|
||||
":kernel_util",
|
||||
|
@ -55,9 +55,6 @@ CpuBackendContext::CpuBackendContext()
|
||||
ruy_context_(new ruy::Context),
|
||||
gemmlowp_context_(new gemmlowp::GemmContext) {
|
||||
SetMaxNumThreads(kDefaultNumThreadpoolThreads);
|
||||
#ifdef TFLITE_WITH_RUY_GEMV
|
||||
ruy_context_->set_cache_policy(ruy::CachePolicy::kCacheLHSOnNarrowMul);
|
||||
#endif
|
||||
}
|
||||
|
||||
CpuBackendContext::~CpuBackendContext() {}
|
||||
|
@ -29,6 +29,17 @@ namespace cpu_backend_gemm {
|
||||
// Matrix storage order: column-major or row-major.
|
||||
enum class Order { kColMajor, kRowMajor };
|
||||
|
||||
enum class CachePolicy : std::uint8_t {
|
||||
kNeverCache,
|
||||
kCacheIfLargeSpeedup,
|
||||
kAlwaysCache,
|
||||
};
|
||||
|
||||
inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
|
||||
return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
|
||||
: CachePolicy::kNeverCache;
|
||||
}
|
||||
|
||||
// MatrixParams encapsulates the parameters that Gemm needs about each
|
||||
// matrix, besides the buffer data pointer.
|
||||
// Compare to ruy::Matrix, which also encapsulates the data pointer.
|
||||
@ -47,10 +58,13 @@ struct MatrixParams {
|
||||
// The zero_point, i.e. which Scalar value is to be interpreted as zero.
|
||||
// When Scalar is floating-point, this must be 0.
|
||||
Scalar zero_point = 0;
|
||||
// Indicate whether the underlying data will remain unchanged for
|
||||
// some period of time. Defaults to false, but should be set to true
|
||||
// for unchanging data (e.g. weights buffers in many cases)
|
||||
bool cacheable = false;
|
||||
// When the data pointed to by this matrix is constant data, so that it is
|
||||
// valid to assume that equality of pointers implies equality of data,
|
||||
// a CachePolicy may be used instead of the default kNeverCache,
|
||||
// which will enable ruy to take advantage of this constancy of the data to
|
||||
// cache the packing work, which can be a large speedup in matrix*vector
|
||||
// and other narrow shapes.
|
||||
CachePolicy cache_policy = CachePolicy::kNeverCache;
|
||||
};
|
||||
|
||||
// Enumeration of broad categories of Gemm.
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_RUY_H_
|
||||
|
||||
#include "ruy/matrix.h" // from @ruy
|
||||
#include "ruy/path.h" // from @ruy
|
||||
#include "ruy/ruy.h" // from @ruy
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
@ -25,6 +26,20 @@ namespace tflite {
|
||||
namespace cpu_backend_gemm {
|
||||
namespace detail {
|
||||
|
||||
inline ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy) {
|
||||
switch (cache_policy) {
|
||||
case CachePolicy::kNeverCache:
|
||||
return ruy::CachePolicy::kNeverCache;
|
||||
case CachePolicy::kCacheIfLargeSpeedup:
|
||||
return ruy::CachePolicy::kCacheIfLargeSpeedup;
|
||||
case CachePolicy::kAlwaysCache:
|
||||
return ruy::CachePolicy::kAlwaysCache;
|
||||
default:
|
||||
TFLITE_DCHECK(false);
|
||||
return ruy::CachePolicy::kNeverCache;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Scalar, typename DataPointer>
|
||||
void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||
ruy::Matrix<Scalar>* dst) {
|
||||
@ -37,7 +52,9 @@ void MakeRuyMatrix(const MatrixParams<Scalar>& params, DataPointer data_ptr,
|
||||
// It does care whether we assign to it a Scalar* or a const Scalar*.
|
||||
dst->set_data(data_ptr);
|
||||
dst->set_zero_point(params.zero_point);
|
||||
dst->set_cacheable(params.cacheable);
|
||||
#ifdef TFLITE_WITH_RUY_GEMV
|
||||
dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GemmParamsType, typename RuySpecType>
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_context.h"
|
||||
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
|
||||
#include "tensorflow/lite/kernels/internal/quantization_util.h"
|
||||
|
@ -1042,7 +1042,7 @@ void NeonCpuBackendGemm(const int8_t* input, const int32_t* bias,
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.rows = n_output;
|
||||
lhs_params.cols = n_input;
|
||||
lhs_params.cacheable = true;
|
||||
lhs_params.cache_policy = cpu_backend_gemm::CachePolicy::kCacheIfLargeSpeedup;
|
||||
|
||||
MatrixParams<int8_t> rhs_params;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
|
@ -286,13 +286,15 @@ inline void FullyConnected(
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.rows = input_rows;
|
||||
rhs_params.cols = input_shape.FlatSize() / input_rows;
|
||||
rhs_params.cacheable = params.rhs_cacheable;
|
||||
rhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||
TFLITE_DCHECK_EQ(input_shape.FlatSize(), rhs_params.rows * rhs_params.cols);
|
||||
cpu_backend_gemm::MatrixParams<float> lhs_params;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.cols = weights_shape.Dims(dims_count - 1);
|
||||
lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
|
||||
lhs_params.cacheable = params.lhs_cacheable;
|
||||
lhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<float> dst_params;
|
||||
dst_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
|
||||
@ -345,13 +347,15 @@ inline void FullyConnected(
|
||||
lhs_params.cols = filter_cols;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.zero_point = -filter_offset;
|
||||
lhs_params.cacheable = params.lhs_cacheable;
|
||||
lhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
||||
rhs_params.rows = filter_cols;
|
||||
rhs_params.cols = batches;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.zero_point = -input_offset;
|
||||
rhs_params.cacheable = params.rhs_cacheable;
|
||||
rhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<uint8> dst_params;
|
||||
dst_params.rows = filter_rows;
|
||||
dst_params.cols = batches;
|
||||
@ -404,13 +408,15 @@ inline void FullyConnected(
|
||||
lhs_params.cols = accum_depth;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.zero_point = -filter_offset;
|
||||
lhs_params.cacheable = params.lhs_cacheable;
|
||||
lhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<uint8> rhs_params;
|
||||
rhs_params.rows = accum_depth;
|
||||
rhs_params.cols = batches;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.zero_point = -input_offset;
|
||||
rhs_params.cacheable = params.rhs_cacheable;
|
||||
rhs_params.cache_policy =
|
||||
cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable);
|
||||
cpu_backend_gemm::MatrixParams<int16> dst_params;
|
||||
dst_params.rows = output_depth;
|
||||
dst_params.cols = batches;
|
||||
|
@ -69,7 +69,6 @@ cc_library(
|
||||
"xtensa_hifimini/quantize.cc",
|
||||
"xtensa_hifimini/softmax.cc",
|
||||
"xtensa_hifimini/svdf.cc",
|
||||
"xtensa_hifimini/utils.h",
|
||||
],
|
||||
}),
|
||||
hdrs = ["micro_ops.h"],
|
||||
|
@ -48,7 +48,17 @@ cp tensorflow/lite/micro/tools/make/downloads/cmsis/CMSIS/DSP/Include/\
|
||||
arm_math.h mbed-os/cmsis/TARGET_CORTEX_M/arm_math.h
|
||||
```
|
||||
|
||||
This issue will be resolved soon. Now type
|
||||
There's also a dependency to an old cmsis_gcc.h, which you can fix with the
|
||||
following:
|
||||
|
||||
```
|
||||
tensorflow/lite/micro/tools/make/downloads/cmsis/CMSIS/Core/Include/\
|
||||
cmsis_gcc.h mbed-os/cmsis/TARGET_CORTEX_M/cmsis_gcc.h
|
||||
```
|
||||
|
||||
This issue will be resolved soon.
|
||||
|
||||
Now type:
|
||||
|
||||
```
|
||||
mbed compile -m DISCO_F746NG -t GCC_ARM
|
||||
|
@ -145,7 +145,7 @@ TfLiteStatus AverageEvalInt8(TfLiteContext* context, const TfLiteNode* node,
|
||||
ARM_MATH_SUCCESS);
|
||||
#else
|
||||
#pragma message( \
|
||||
"CMSIS-NN optimization for depthwise_conv not available for this target. Using reference kernel.")
|
||||
"CMSIS-NN optimization for avg_pool not available for this target. Using reference kernel.")
|
||||
|
||||
PoolParams op_params;
|
||||
op_params.stride_height = params->stride_height;
|
||||
@ -165,8 +165,8 @@ TfLiteStatus AverageEvalInt8(TfLiteContext* context, const TfLiteNode* node,
|
||||
}
|
||||
|
||||
void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLitePoolParams* params, OpData* data,
|
||||
const TfLiteTensor* input, TfLiteTensor* output) {
|
||||
TfLitePoolParams* params, OpData* data, TfLiteTensor* input,
|
||||
TfLiteTensor* output) {
|
||||
float activation_min, activation_max;
|
||||
CalculateActivationRange(params->activation, &activation_min,
|
||||
&activation_max);
|
||||
@ -187,7 +187,7 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
|
||||
void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLitePoolParams* params, OpData* data,
|
||||
const TfLiteTensor* input, TfLiteTensor* output) {
|
||||
TfLiteTensor* input, TfLiteTensor* output) {
|
||||
int32_t activation_min, activation_max;
|
||||
(void)CalculateActivationRangeQuantized(context, params->activation, output,
|
||||
&activation_min, &activation_max);
|
||||
@ -206,6 +206,74 @@ void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
GetTensorData<uint8_t>(output));
|
||||
}
|
||||
|
||||
TfLiteStatus MaxEvalInt8(TfLiteContext* context, const TfLiteNode* node,
|
||||
const TfLitePoolParams* params, const OpData* data,
|
||||
TfLiteTensor* input, TfLiteTensor* output) {
|
||||
int32_t activation_min, activation_max;
|
||||
(void)CalculateActivationRangeQuantized(context, params->activation, output,
|
||||
&activation_min, &activation_max);
|
||||
|
||||
TFLITE_DCHECK_LE(activation_min, activation_max);
|
||||
|
||||
#if defined(__ARM_FEATURE_DSP)
|
||||
RuntimeShape input_shape = GetTensorShape(input);
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
|
||||
|
||||
RuntimeShape output_shape = GetTensorShape(output);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
|
||||
|
||||
const int depth = MatchingDim(input_shape, 3, output_shape, 3);
|
||||
const int input_height = input_shape.Dims(1);
|
||||
const int input_width = input_shape.Dims(2);
|
||||
const int output_height = output_shape.Dims(1);
|
||||
const int output_width = output_shape.Dims(2);
|
||||
const int stride_height = params->stride_height;
|
||||
const int stride_width = params->stride_width;
|
||||
|
||||
const int filter_height = params->filter_height;
|
||||
const int filter_width = params->filter_width;
|
||||
const int padding_height = data->padding.height;
|
||||
const int padding_width = data->padding.width;
|
||||
|
||||
int16_t* scratch_buffer = nullptr;
|
||||
|
||||
auto* buffer_idx = reinterpret_cast<int*>(node->user_data);
|
||||
|
||||
if (*buffer_idx > -1) {
|
||||
void* raw = context->GetScratchBuffer(context, *buffer_idx);
|
||||
scratch_buffer = reinterpret_cast<int16_t*>(raw);
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE_EQ(
|
||||
context,
|
||||
arm_max_pool_s8_opt(input_height, input_width, output_height,
|
||||
output_width, stride_height, stride_width,
|
||||
filter_height, filter_width, padding_height,
|
||||
padding_width, activation_min, activation_max, depth,
|
||||
GetTensorData<int8_t>(input), scratch_buffer,
|
||||
GetTensorData<int8_t>(output)),
|
||||
ARM_MATH_SUCCESS);
|
||||
#else
|
||||
#pragma message( \
|
||||
"CMSIS-NN optimization for max_pool not available for this target. Using reference kernel.")
|
||||
|
||||
PoolParams op_params;
|
||||
op_params.stride_height = params->stride_height;
|
||||
op_params.stride_width = params->stride_width;
|
||||
op_params.filter_height = params->filter_height;
|
||||
op_params.filter_width = params->filter_width;
|
||||
op_params.padding_values.height = data->padding.height;
|
||||
op_params.padding_values.width = data->padding.width;
|
||||
op_params.quantized_activation_min = activation_min;
|
||||
op_params.quantized_activation_max = activation_max;
|
||||
reference_integer_ops::MaxPool(
|
||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||
GetTensorShape(output), GetTensorData<int8_t>(output));
|
||||
|
||||
#endif
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
@ -278,7 +346,8 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
|
||||
OpData data;
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
TfLiteTensor* input = &context->tensors[flatbuffers::EndianScalar(
|
||||
node->inputs->data[kInputTensor])];
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data));
|
||||
@ -290,6 +359,9 @@ TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteUInt8:
|
||||
MaxEvalQuantizedUInt8(context, node, params, &data, input, output);
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
MaxEvalInt8(context, node, params, &data, input, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
|
||||
TfLiteTypeGetName(input->type));
|
||||
|
@ -33,7 +33,7 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kFilterTensor = 1;
|
||||
constexpr int kBiasTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
constexpr int kMaxChannels = 256;
|
||||
constexpr int kMaxChannels = 1024;
|
||||
|
||||
// Conv is quantized along dimension 0:
|
||||
// https://www.tensorflow.org/lite/performance/quantization_spec
|
||||
|
@ -35,7 +35,7 @@ constexpr int kInputTensor = 0;
|
||||
constexpr int kFilterTensor = 1;
|
||||
constexpr int kBiasTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
constexpr int kMaxChannels = 256;
|
||||
constexpr int kMaxChannels = 1024;
|
||||
|
||||
// Depthwise conv is quantized along dimension 3:
|
||||
// https://www.tensorflow.org/lite/performance/quantization_spec
|
||||
|
@ -48,7 +48,7 @@ constexpr int kBiasTensor = 2;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
TfLiteFullyConnectedParams* params,
|
||||
TfLiteFusedActivation activation,
|
||||
TfLiteType data_type, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output,
|
||||
@ -62,7 +62,7 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
|
||||
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
|
||||
data->output_shift = -exponent;
|
||||
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
|
||||
context, params->activation, output, &data->output_activation_min,
|
||||
context, activation, output, &data->output_activation_min,
|
||||
&data->output_activation_max));
|
||||
}
|
||||
return status;
|
||||
@ -85,19 +85,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteFullyConnectedParams* params, OpData* data,
|
||||
const TfLiteTensor* input,
|
||||
const OpData& data, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output) {
|
||||
FullyConnectedParams op_params;
|
||||
op_params.input_offset = -input->params.zero_point;
|
||||
op_params.weights_offset = -filter->params.zero_point;
|
||||
op_params.output_offset = output->params.zero_point;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
// TODO(b/138810107): Figure out whether output shift should be inverted
|
||||
op_params.output_shift = -data->output_shift;
|
||||
op_params.quantized_activation_min = data->output_activation_min;
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
op_params.output_shift = -data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
reference_integer_ops::FullyConnected(
|
||||
op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
|
||||
@ -108,8 +107,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
|
||||
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteFullyConnectedParams* params, OpData* data,
|
||||
const TfLiteTensor* input,
|
||||
const OpData& data, const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter, const TfLiteTensor* bias,
|
||||
TfLiteTensor* output) {
|
||||
const int32_t input_offset = -input->params.zero_point;
|
||||
@ -120,11 +118,11 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
op_params.input_offset = input_offset;
|
||||
op_params.weights_offset = filter_offset;
|
||||
op_params.output_offset = output_offset;
|
||||
op_params.output_multiplier = data->output_multiplier;
|
||||
op_params.output_multiplier = data.output_multiplier;
|
||||
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
|
||||
op_params.output_shift = -data->output_shift;
|
||||
op_params.quantized_activation_min = data->output_activation_min;
|
||||
op_params.quantized_activation_max = data->output_activation_max;
|
||||
op_params.output_shift = -data.output_shift;
|
||||
op_params.quantized_activation_min = data.output_activation_min;
|
||||
op_params.quantized_activation_max = data.output_activation_max;
|
||||
|
||||
#define TF_LITE_FULLY_CONNECTED(output_data_type) \
|
||||
reference_ops::FullyConnected( \
|
||||
@ -149,11 +147,11 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
|
||||
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteFullyConnectedParams* params, OpData* data,
|
||||
TfLiteFusedActivation activation,
|
||||
const TfLiteTensor* input, const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* output) {
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
CalculateActivationRange(activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
tflite::FullyConnectedParams op_params;
|
||||
op_params.float_activation_min = output_activation_min;
|
||||
@ -167,8 +165,9 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
auto* params =
|
||||
reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
TFLITE_DCHECK(node->builtin_data != nullptr);
|
||||
const auto* params =
|
||||
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
|
||||
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
|
||||
@ -176,23 +175,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
TfLiteType data_type = input->type;
|
||||
OpData local_data_object;
|
||||
OpData* data = &local_data_object;
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
|
||||
filter, bias, output, data));
|
||||
OpData data;
|
||||
TF_LITE_ENSURE_STATUS(CalculateOpData(context, params->activation, data_type,
|
||||
input, filter, bias, output, &data));
|
||||
|
||||
// Checks in Prepare ensure input, output and filter types are all the same.
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
return EvalFloat(context, node, params, data, input, filter, bias,
|
||||
return EvalFloat(context, node, params->activation, input, filter, bias,
|
||||
output);
|
||||
case kTfLiteInt8:
|
||||
return EvalQuantizedInt8(context, node, params, data, input, filter, bias,
|
||||
return EvalQuantizedInt8(context, node, data, input, filter, bias,
|
||||
output);
|
||||
|
||||
case kTfLiteUInt8:
|
||||
return EvalQuantized(context, node, params, data, input, filter, bias,
|
||||
output);
|
||||
return EvalQuantized(context, node, data, input, filter, bias, output);
|
||||
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -66,7 +65,7 @@ void ConvPerChannel(const ConvParams& params, const int32* output_multiplier,
|
||||
const int output_width = output_shape.Dims(2);
|
||||
const int output_depth = output_shape.Dims(3);
|
||||
|
||||
ae_p24x2s input_offset_24x2 = AE_CONVERT_INT32_24x2(input_offset);
|
||||
ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
|
||||
ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
|
||||
ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
|
||||
ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
|
||||
@ -150,9 +149,6 @@ void ConvPerChannel(const ConvParams& params, const int32* output_multiplier,
|
||||
acc_24x2, output_multiplier[out_channel],
|
||||
output_shift[out_channel]);
|
||||
|
||||
// Shift from 48bit aligned to 32bit:
|
||||
acc_56 = AE_Q56S_SLAI(acc_56, 16);
|
||||
|
||||
// Add output offset, cap activation, and assign to the output:
|
||||
acc_56 = AE_ADDQ56(acc_56, output_offset_56);
|
||||
acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
|
||||
@ -178,7 +174,7 @@ inline void Conv1x32Input32x32Filter(
|
||||
const RuntimeShape& filter_shape, const int8* filter_data,
|
||||
const RuntimeShape& bias_shape, const int32* bias_data,
|
||||
const RuntimeShape& output_shape, int8* output_data) {
|
||||
ae_p24x2s input_offset_24x2 = AE_CONVERT_INT32_24x2(input_offset);
|
||||
ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
|
||||
ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
|
||||
ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max);
|
||||
ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min);
|
||||
@ -227,13 +223,10 @@ inline void Conv1x32Input32x32Filter(
|
||||
acc_56 = AE_Q56S_SLAI(acc_56, 8);
|
||||
ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56);
|
||||
|
||||
// Apply quantized multiplier and accumulate result at 48bit
|
||||
// alignment:
|
||||
// Apply quantized multiplier and accumulate result at 48bit alignment.
|
||||
// Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier.
|
||||
acc_56 = micro::xtensa::hifimini::MultiplyByQuantizedMultiplier(
|
||||
acc_24x2, output_multiplier[ch], output_shift[ch]);
|
||||
|
||||
// Shift from 48bit aligned to 32bit:
|
||||
acc_56 = AE_Q56S_SLAI(acc_56, 16);
|
||||
acc_24x2, output_multiplier[ch] >> 8, output_shift[ch]);
|
||||
|
||||
// Add output offset, cap activation, and assign to the output:
|
||||
acc_56 = AE_ADDQ56(acc_56, output_offset_56);
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/padding.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -69,7 +68,7 @@ inline void DepthwiseConvPerChannel(
|
||||
const int output_width = output_shape.Dims(2);
|
||||
const int output_depth = output_shape.Dims(3);
|
||||
|
||||
ae_p24x2s input_offset_24x2 = AE_CONVERT_INT32_24x2(input_offset);
|
||||
ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
|
||||
ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
|
||||
ae_q56s output_activation_min_56 = AE_CVTQ48A32S(output_activation_min);
|
||||
ae_q56s output_activation_max_56 = AE_CVTQ48A32S(output_activation_max);
|
||||
@ -114,14 +113,14 @@ inline void DepthwiseConvPerChannel(
|
||||
// shift into 24bit space. Note: value is duplicated in the HH
|
||||
// and LL register - but all calculations are done on the HH
|
||||
// side.
|
||||
ae_p24x2s input_val_24x2 = AE_CONVERT_INT32_24x2(input_val);
|
||||
ae_p24x2s input_val_24x2 = AE_MOVPA24(input_val);
|
||||
|
||||
// Add input offset (24bit aligned):
|
||||
input_val_24x2 =
|
||||
AE_P24S_ADDS_P24X2S(input_val_24x2, input_offset_24x2);
|
||||
|
||||
// Load filter 8bit value into 24bit alignment:
|
||||
ae_p24x2s filter_val_24x2 = AE_CONVERT_INT32_24x2(filter_val);
|
||||
ae_p24x2s filter_val_24x2 = AE_MOVPA24(filter_val);
|
||||
|
||||
// Multiply and accumulate the HH side of each 24x24 PR
|
||||
// register:
|
||||
@ -150,9 +149,6 @@ inline void DepthwiseConvPerChannel(
|
||||
acc_24x2, output_multiplier[output_channel],
|
||||
output_shift[output_channel]);
|
||||
|
||||
// Shift from 48bit aligned to 32bit:
|
||||
acc_56 = AE_Q56S_SLAI(acc_56, 16);
|
||||
|
||||
// Add output offset, cap activation, and assign to the output:
|
||||
acc_56 = AE_ADDQ56(acc_56, output_offset_56);
|
||||
acc_56 = AE_MINQ56S(acc_56, output_activation_max_56);
|
||||
@ -181,9 +177,10 @@ inline void DepthwiseConv4x32MatchingInputAndFilter(
|
||||
const RuntimeShape& filter_shape, const int8* filter_data,
|
||||
const RuntimeShape& bias_shape, const int32* bias_data,
|
||||
const RuntimeShape& output_shape, int8* output_data) {
|
||||
const int32_t mult = output_multiplier[0];
|
||||
// Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier.
|
||||
const int32_t mult = output_multiplier[0] >> 8;
|
||||
const int32_t shift = output_shift[0];
|
||||
ae_p24x2s input_offset_24x2 = AE_CONVERT_INT32_24x2(input_offset);
|
||||
ae_p24x2s input_offset_24x2 = AE_MOVPA24(input_offset);
|
||||
ae_q56s output_offset_56 = AE_CVTQ48A32S(output_offset);
|
||||
ae_q56s output_activation_min_56 = AE_CVTQ48A32S(quantized_activation_min);
|
||||
ae_q56s output_activation_max_56 = AE_CVTQ48A32S(quantized_activation_max);
|
||||
@ -270,10 +267,6 @@ inline void DepthwiseConv4x32MatchingInputAndFilter(
|
||||
block_1_acc = micro::xtensa::hifimini::MultiplyByQuantizedMultiplier(
|
||||
acc_24x2_1, mult, shift);
|
||||
|
||||
// Shift from 48bit aligned to 32bit:
|
||||
block_0_acc = AE_Q56S_SLAI(block_0_acc, 16);
|
||||
block_1_acc = AE_Q56S_SLAI(block_1_acc, 16);
|
||||
|
||||
// Add output offset, cap activation, and assign to the output:
|
||||
block_0_acc = AE_ADDQ56(block_0_acc, output_offset_56);
|
||||
block_1_acc = AE_ADDQ56(block_1_acc, output_offset_56);
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -31,80 +30,9 @@ namespace micro {
|
||||
namespace xtensa {
|
||||
namespace hifimini {
|
||||
|
||||
//
|
||||
// Multiply 32bit value by a quantized multiplier (w/ shift) and returns a 48bit
|
||||
// aligned value in the QR register.
|
||||
//
|
||||
inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x,
|
||||
int32_t quantized_multiplier,
|
||||
int shift) {
|
||||
// These boolean factors will carry an additional 2^8 (e.g 256) factor
|
||||
// throughout the equation to cover the missing 8 bits of precision when a
|
||||
// 32bit integer is outside the bounds of INT24. The additional scaling factor
|
||||
// will be adjusted after the final multiplication in this method.
|
||||
//
|
||||
// The Q-notation comments in this method describe the calculations that take
|
||||
// place when both |x| and the shifted value of |1| overflow the INT24 limits.
|
||||
bool x_exceeds_24bits = (x <= INT24_MIN || x >= INT24_MAX);
|
||||
bool shift_exceeds_24bits = false;
|
||||
|
||||
// Q31.0 -> Q23.0 / 2^8
|
||||
ae_p24x2s x_24x2 = AE_CONVERT_INT32_24x2(x);
|
||||
|
||||
if (shift > 0) {
|
||||
int shifted = 1 << shift;
|
||||
if (shifted <= INT24_MIN || shifted >= INT24_MAX) {
|
||||
shift_exceeds_24bits = true;
|
||||
}
|
||||
|
||||
// Load the shifted value into the PR register:
|
||||
// Q31.0 -> Q23.0 / 2^8
|
||||
ae_p24x2s shifted_24x2 = AE_CONVERT_INT32_24x2(shifted);
|
||||
|
||||
// (Q23.0 / 2^8) * (Q23.0 / 2^8) = Q47.0 / 2^16
|
||||
ae_q56s sum_56 = AE_MULP24S_HH(x_24x2, shifted_24x2);
|
||||
|
||||
// Shift left into 24bit space:
|
||||
// ((Q47.0 / 2^16) << 24) = Q23.24 / 2^16
|
||||
sum_56 = AE_Q56S_SLAI(sum_56, 24);
|
||||
|
||||
// Truncate and place on the PR register:
|
||||
// (Q23.24 / 2^16) -> Q23.0 / 2^16
|
||||
x_24x2 = AE_TRUNCP24Q48(sum_56);
|
||||
}
|
||||
|
||||
// Load the quantized multiplier into the PR register.
|
||||
// NOTE: This method assumes that this param has been calculated for 24bit
|
||||
// space - not 32bits.
|
||||
// Q0.31 -> Q0.23
|
||||
ae_p24x2s quantized_multiplier_24x2 =
|
||||
AE_CONVERT_INT32_24x2(quantized_multiplier);
|
||||
|
||||
// Adjust for the additional 8 bits of lost precision throughout this
|
||||
// function:
|
||||
int shift_amount = 23;
|
||||
if (x_exceeds_24bits) {
|
||||
shift_amount = shift_amount - 8;
|
||||
}
|
||||
if (shift_exceeds_24bits) {
|
||||
shift_amount = shift_amount - 8;
|
||||
}
|
||||
|
||||
// Find the product of x and the quantized_multiplier and right shift
|
||||
// to 48bit aligned.
|
||||
// (Q23.0 / 2^16) * Q23.0 = Q47.0 / 2^16
|
||||
// (Q47.0 / 2^16) >> 7 = Q47.0
|
||||
ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2);
|
||||
if (shift_amount > 0) {
|
||||
result_56 = AE_Q56S_SRA(result_56, shift_amount);
|
||||
}
|
||||
|
||||
if (shift < 0) {
|
||||
// Handle any negative shift directly on the 48 bit value.
|
||||
result_56 = AE_Q56S_SRA(result_56, -shift);
|
||||
}
|
||||
return result_56;
|
||||
}
|
||||
// INT24 MIN/MAX
|
||||
#define INT24_MIN -8388608
|
||||
#define INT24_MAX 8388607
|
||||
|
||||
//
|
||||
// Multiply 24bit value by a quantized multiplier (w/ shift) and returns a 48bit
|
||||
@ -113,62 +41,62 @@ inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x,
|
||||
inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2,
|
||||
int32_t quantized_multiplier,
|
||||
int shift) {
|
||||
// NOTE: x_24x2 = Q23.0
|
||||
|
||||
// This is an optimized version of a 32 bit MultiplyByQuantizedMultiplier
|
||||
// operation of TFLite. Sometimes, the shifted value of |x_24x2| can exceed
|
||||
// the limits of INT24, which requires |AE_CONVERT_INT32_24x2()| to load the
|
||||
// left-most 24 bits of a 32bit integer. When this occurs, all Q values here
|
||||
// carry an additional division of 2^8 to account for this loss in precision.
|
||||
// This division will be applied to the final shift after multiplication.
|
||||
// A value with 1 sign bit, N integer bits and M fractional bits is
|
||||
// represented as QN+1.M since the sign bit is included in the integer bits.
|
||||
//
|
||||
// The Q notation in this method explains the values represented in each
|
||||
// variable, along with an implicit division since the quantized_multiplier
|
||||
// represents a value between 0.5 and 1.0 (Q1.X-1 where X is the bit precision
|
||||
// of the type).
|
||||
//
|
||||
// The Q-notation comments in this method describe the calculations that take
|
||||
// place when both |x| and the shifted value of |1| overflow the INT24 limits.
|
||||
bool shift_exceeds_24bits = false;
|
||||
|
||||
ae_p24x2s x_shifted_24x2 = x_24x2;
|
||||
if (shift > 0) {
|
||||
int shifted = 1 << shift;
|
||||
if (shifted <= INT24_MIN || shifted >= INT24_MAX) {
|
||||
shift_exceeds_24bits = true;
|
||||
}
|
||||
// Load the shifted value into the PR register:
|
||||
// Q31.0 -> Q23.0 / 2^8
|
||||
ae_p24x2s shifted_24x2 = AE_CONVERT_INT32_24x2(shifted);
|
||||
|
||||
// Q23.0 * (Q23.0 / 2^8) = Q47.0 / 2^8
|
||||
ae_q56s sum_56 = AE_MULP24S_HH(x_24x2, shifted_24x2);
|
||||
|
||||
// Shift left into 24bit space:
|
||||
// ((Q47.0 / 2^8) << 24) = Q23.24 / 2^8
|
||||
sum_56 = AE_Q56S_SLAI(sum_56, 24);
|
||||
|
||||
// Truncate and place on the PR register:
|
||||
// (Q23.24 / 2^8) -> Q23.0 / 2^8
|
||||
x_shifted_24x2 = AE_ROUNDSP24Q48SYM(sum_56);
|
||||
}
|
||||
|
||||
// Load the quantized multiplier into the PR register.
|
||||
// NOTE: This method assumes that this param has been calculated for 24bit
|
||||
// space - not 32bits.
|
||||
// Q0.31 -> Q0.23
|
||||
ae_p24x2s quantized_multiplier_24x2 =
|
||||
AE_CONVERT_INT32_24x2(quantized_multiplier);
|
||||
// Q32.0 / 2^23 -> Q24.0 / 2^23 representing a Q1.23 multiplier.
|
||||
ae_p24x2s quantized_multiplier_24x2 = AE_MOVPA24(quantized_multiplier);
|
||||
// Shift right by 23 - 16 bits minus the specified shift. This is because we
|
||||
// keep 16 fractional bits until the end to perform rounding. Subtract shift
|
||||
// since shift is a left shift, and the 23-16 is a right shift.
|
||||
int shift_amount = 7 - shift;
|
||||
|
||||
// Find the product of x and the quantized_multiplier and right shift
|
||||
// to 48bit aligned.
|
||||
// NOTE: Adjust for the additional 8 bits of lost precision throughout this
|
||||
// function:
|
||||
// (Q23.0 / 2^8) * Q23.0 = Q47.0 / 2^8
|
||||
// (Q47.0 / 2^8) >> 7 = Q47.0
|
||||
ae_q56s result = AE_MULP24S_HH(x_shifted_24x2, quantized_multiplier_24x2);
|
||||
result = AE_Q56S_SRA(result, shift_exceeds_24bits ? 15 : 23);
|
||||
// Find the product of x and the quantized_multiplier.
|
||||
// Q24.0 / 2^23 * Q24.0 = Q48.0 / 2^23
|
||||
// Q48.0 / 2^23 >> 7 = Q48.0 / 2^16
|
||||
ae_q56s result_56 = AE_MULP24S_HH(x_24x2, quantized_multiplier_24x2);
|
||||
|
||||
if (shift < 0) {
|
||||
// Handle any negative shift directly on the 48 bit value.
|
||||
result = AE_Q56S_SRA(result, -shift);
|
||||
// Shift right if shift amount is positive, left if shift amount is negative.
|
||||
if (shift_amount >= 0) {
|
||||
result_56 = AE_Q56S_SRA(result_56, shift_amount);
|
||||
} else {
|
||||
result_56 = AE_Q56S_SLA(result_56, -shift_amount);
|
||||
}
|
||||
return result;
|
||||
|
||||
// Round off the bottom 16 bits.
|
||||
// Q48.0 / 2^16 -> Q32.0 aligned to 48 bits.
|
||||
result_56 = AE_ROUNDSQ32SYM(result_56);
|
||||
return result_56;
|
||||
}
|
||||
|
||||
//
|
||||
// Multiply 32bit value by a quantized multiplier (w/ shift) and returns a 48bit
|
||||
// aligned value in the QR register.
|
||||
//
|
||||
inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x,
|
||||
int32_t quantized_multiplier,
|
||||
int shift) {
|
||||
// Convert x into a 2x24bit PR register file. If x is outside the numerical
|
||||
// limits of a 24bit integer, the "fractional" or lower 8bits are discarded.
|
||||
// If x is within the range of a 24 bit integer, the "signed" or upper 8bits
|
||||
// are discarded.
|
||||
ae_p24x2s x_24x2;
|
||||
if (x > INT24_MIN && x < INT24_MAX) {
|
||||
x_24x2 = AE_MOVPA24(x);
|
||||
} else {
|
||||
x_24x2 = static_cast<ae_p24s>(*reinterpret_cast<ae_p24f*>(&x));
|
||||
shift += 8;
|
||||
}
|
||||
|
||||
return MultiplyByQuantizedMultiplier(x_24x2, quantized_multiplier, shift);
|
||||
}
|
||||
|
||||
//
|
||||
@ -193,6 +121,8 @@ inline void QuantizeMultiplier(float multiplier, int32_t* quantized_multiplier,
|
||||
}
|
||||
TFLITE_CHECK_LE(q_fixed, INT24_MAX);
|
||||
|
||||
// Ensure shift does not exceed 24-bit range.
|
||||
TFLITE_CHECK_LE(*shift, 23);
|
||||
if (*shift < -23) {
|
||||
*shift = 0;
|
||||
q_fixed = 0;
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -108,9 +107,6 @@ inline void FullyConnected(
|
||||
sum_56 = MultiplyByQuantizedMultiplier(sum_24x2, output_multiplier,
|
||||
output_shift);
|
||||
|
||||
// Align from 48bit to 32bit on the QR register:
|
||||
sum_56 = AE_Q56S_SLAI(sum_56, 16);
|
||||
|
||||
// Add output_offset and cap min/max values:
|
||||
sum_56 = AE_ADDQ56(sum_56, output_offset_56);
|
||||
sum_56 = AE_MINQ56S(sum_56, output_activation_max_56);
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -43,7 +42,7 @@ void AffineQuantize(int scale_multiplier,
|
||||
|
||||
const ae_p16x2s* input_data_ptr = (const ae_p16x2s*)(input_data - 2);
|
||||
|
||||
ae_p24x2s scale_multiplier_24x2 = AE_CONVERT_INT32_24x2(scale_multiplier);
|
||||
ae_p24x2s scale_multiplier_24x2 = AE_MOVPA24(scale_multiplier);
|
||||
|
||||
int iters = flat_size / 2;
|
||||
for (int i = 0; i < iters; i++) {
|
||||
|
@ -25,8 +25,6 @@ limitations under the License.
|
||||
#include "tensorflow/lite/kernels/op_macros.h"
|
||||
#include "tensorflow/lite/micro/kernels/activation_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h"
|
||||
#include "tensorflow/lite/micro/kernels/xtensa_hifimini/utils.h"
|
||||
#include "tensorflow/lite/micro/micro_utils.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
@ -99,7 +97,7 @@ void EvalIntegerSVDF(
|
||||
|
||||
ae_q56s output_int16_max_56 = AE_CVTQ48A32S(INT16_MAX);
|
||||
ae_q56s output_int16_min_56 = AE_CVTQ48A32S(INT16_MIN);
|
||||
ae_p24x2s input_zp_24x2 = AE_CONVERT_INT32_24x2(input_zp);
|
||||
ae_p24x2s input_zp_24x2 = AE_MOVPA24(input_zp);
|
||||
|
||||
for (int b = 0; b < n_batch; b++) {
|
||||
const int8_t* weight_feature_ptr = weight_feature - 2;
|
||||
@ -140,8 +138,6 @@ void EvalIntegerSVDF(
|
||||
tflite::ops::micro::xtensa::hifimini::MultiplyByQuantizedMultiplier(
|
||||
dot_prod_24x2, scale_1_a, scale_1_b);
|
||||
|
||||
// Align from 48bit to 32bit on the QR register
|
||||
dot_prod_56 = AE_Q56S_SLAI(dot_prod_56, 16);
|
||||
// Cap min/max and convert to int32:
|
||||
dot_prod_56 = AE_MAXQ56S(dot_prod_56, output_int16_min_56);
|
||||
dot_prod_56 = AE_MINQ56S(dot_prod_56, output_int16_max_56);
|
||||
@ -232,8 +228,6 @@ void EvalIntegerSVDF(
|
||||
ae_q56s x_56 =
|
||||
tflite::ops::micro::xtensa::hifimini::MultiplyByQuantizedMultiplier(
|
||||
scratch_output_tensor[i], scale_2_a, scale_2_b);
|
||||
// Align from 48bit to 32bit on the QR register:
|
||||
x_56 = AE_Q56S_SLAI(x_56, 16);
|
||||
// Add output adjustment:
|
||||
x_56 = AE_ADDQ56(x_56, output_zp_56);
|
||||
// Cap min/max and convert to int32 (already aligned to 32bit):
|
||||
|
@ -1,42 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_UTILS_H_
|
||||
#define TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_UTILS_H_
|
||||
|
||||
#include <xtensa/tie/xt_hifi2.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
// INT24 MIN/MAX
|
||||
#define INT24_MIN -8388608
|
||||
#define INT24_MAX 8388607
|
||||
|
||||
// Converts an int32 value into a 2x24bit PR register file. If the int32 value
|
||||
// is outside the numerical limits of a 24bit integer, the "fractional" or lower
|
||||
// 8bits are discarded. If the value is within the range of a 24 bit integer,
|
||||
// the "signed" or upper 8bits are discarded.
|
||||
inline ae_p24x2s AE_CONVERT_INT32_24x2(int32_t v) {
|
||||
if (v > INT24_MIN && v < INT24_MAX) {
|
||||
return *reinterpret_cast<ae_p24s*>(&v);
|
||||
} else {
|
||||
return static_cast<ae_p24s>(*reinterpret_cast<ae_p24f*>(&v));
|
||||
}
|
||||
}
|
||||
|
||||
// Shifts a 48bit accumulator value into 32bit space and returns the value.
|
||||
#define AE_CONVERT_Q56_INT32(v) AE_TRUNCA32Q48(AE_Q56S_SLAI(v, 16))
|
||||
|
||||
#endif // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_UTILS_H_
|
@ -18,6 +18,33 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace testing {
|
||||
|
||||
uint8_t F2Q(float value, float min, float max) {
|
||||
int32_t result = ZeroPointFromMinMax<uint8_t>(min, max) +
|
||||
(value / ScaleFromMinMax<uint8_t>(min, max)) + 0.5f;
|
||||
if (result < std::numeric_limits<uint8_t>::min()) {
|
||||
result = std::numeric_limits<uint8_t>::min();
|
||||
}
|
||||
if (result > std::numeric_limits<uint8_t>::max()) {
|
||||
result = std::numeric_limits<uint8_t>::max();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Converts a float value into a signed eight-bit quantized value.
|
||||
int8_t F2QS(float value, float min, float max) {
|
||||
return F2Q(value, min, max) + std::numeric_limits<int8_t>::min();
|
||||
}
|
||||
|
||||
int32_t F2Q32(float value, float scale) {
|
||||
double quantized = value / scale;
|
||||
if (quantized > std::numeric_limits<int32_t>::max()) {
|
||||
quantized = std::numeric_limits<int32_t>::max();
|
||||
} else if (quantized < std::numeric_limits<int32_t>::min()) {
|
||||
quantized = std::numeric_limits<int32_t>::min();
|
||||
}
|
||||
return static_cast<int>(quantized);
|
||||
}
|
||||
|
||||
// TODO(b/141330728): Move this method elsewhere as part clean up.
|
||||
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
|
||||
ErrorReporter* error_reporter, TfLiteContext* context) {
|
||||
@ -41,5 +68,139 @@ void PopulateContext(TfLiteTensor* tensors, int tensors_size,
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteTensor CreateFloatTensor(std::initializer_list<float> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable) {
|
||||
return CreateFloatTensor(data.begin(), dims, name, is_variable);
|
||||
}
|
||||
|
||||
TfLiteTensor CreateBoolTensor(std::initializer_list<bool> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable) {
|
||||
return CreateBoolTensor(data.begin(), dims, name, is_variable);
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims,
|
||||
const char* name, float min, float max,
|
||||
bool is_variable) {
|
||||
TfLiteTensor result;
|
||||
result.type = kTfLiteUInt8;
|
||||
result.data.uint8 = const_cast<uint8_t*>(data);
|
||||
result.dims = dims;
|
||||
result.params = {ScaleFromMinMax<uint8_t>(min, max),
|
||||
ZeroPointFromMinMax<uint8_t>(min, max)};
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(uint8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(std::initializer_list<uint8_t> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
float min, float max, bool is_variable) {
|
||||
return CreateQuantizedTensor(data.begin(), dims, name, min, max, is_variable);
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims,
|
||||
const char* name, float min, float max,
|
||||
bool is_variable) {
|
||||
TfLiteTensor result;
|
||||
result.type = kTfLiteInt8;
|
||||
result.data.int8 = const_cast<int8_t*>(data);
|
||||
result.dims = dims;
|
||||
result.params = {ScaleFromMinMax<int8_t>(min, max),
|
||||
ZeroPointFromMinMax<int8_t>(min, max)};
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(std::initializer_list<int8_t> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
float min, float max, bool is_variable) {
|
||||
return CreateQuantizedTensor(data.begin(), dims, name, min, max, is_variable);
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(float* data, uint8_t* quantized_data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable) {
|
||||
TfLiteTensor result;
|
||||
SymmetricQuantize(data, dims, quantized_data, &result.params.scale);
|
||||
result.data.uint8 = quantized_data;
|
||||
result.type = kTfLiteUInt8;
|
||||
result.dims = dims;
|
||||
result.params.zero_point = 128;
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(uint8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(float* data, int8_t* quantized_data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable) {
|
||||
TfLiteTensor result;
|
||||
SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
|
||||
result.data.int8 = quantized_data;
|
||||
result.type = kTfLiteInt8;
|
||||
result.dims = dims;
|
||||
result.params.zero_point = 0;
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(float* data, int16_t* quantized_data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable) {
|
||||
TfLiteTensor result;
|
||||
SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
|
||||
result.data.i16 = quantized_data;
|
||||
result.type = kTfLiteInt16;
|
||||
result.dims = dims;
|
||||
result.params.zero_point = 0;
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int16_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantized32Tensor(const int32_t* data, TfLiteIntArray* dims,
|
||||
const char* name, float scale,
|
||||
bool is_variable) {
|
||||
TfLiteTensor result;
|
||||
result.type = kTfLiteInt32;
|
||||
result.data.i32 = const_cast<int32_t*>(data);
|
||||
result.dims = dims;
|
||||
// Quantized int32 tensors always have a zero point of 0, since the range of
|
||||
// int32 values is large, and because zero point costs extra cycles during
|
||||
// processing.
|
||||
result.params = {scale, 0};
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int32_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
float scale, bool is_variable) {
|
||||
return CreateQuantized32Tensor(data.begin(), dims, name, scale, is_variable);
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tflite
|
||||
|
@ -65,182 +65,65 @@ inline int ZeroPointFromMinMax(const float min, const float max) {
|
||||
}
|
||||
|
||||
// Converts a float value into an unsigned eight-bit quantized value.
|
||||
inline uint8_t F2Q(const float value, const float min, const float max) {
|
||||
int32_t result = ZeroPointFromMinMax<uint8_t>(min, max) +
|
||||
(value / ScaleFromMinMax<uint8_t>(min, max)) + 0.5f;
|
||||
if (result < std::numeric_limits<uint8_t>::min()) {
|
||||
result = std::numeric_limits<uint8_t>::min();
|
||||
}
|
||||
if (result > std::numeric_limits<uint8_t>::max()) {
|
||||
result = std::numeric_limits<uint8_t>::max();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
uint8_t F2Q(float value, float min, float max);
|
||||
|
||||
// Converts a float value into a signed eight-bit quantized value.
|
||||
inline int8_t F2QS(const float value, const float min, const float max) {
|
||||
return F2Q(value, min, max) + std::numeric_limits<int8_t>::min();
|
||||
}
|
||||
int8_t F2QS(const float value, const float min, const float max);
|
||||
|
||||
// Converts a float value into a signed thirty-two-bit quantized value. Note
|
||||
// that values close to max int and min int may see significant error due to
|
||||
// a lack of floating point granularity for large values.
|
||||
inline int32_t F2Q32(const float value, const float scale) {
|
||||
double quantized = value / scale;
|
||||
if (quantized > std::numeric_limits<int32_t>::max()) {
|
||||
quantized = std::numeric_limits<int32_t>::max();
|
||||
} else if (quantized < std::numeric_limits<int32_t>::min()) {
|
||||
quantized = std::numeric_limits<int32_t>::min();
|
||||
}
|
||||
return static_cast<int>(quantized);
|
||||
}
|
||||
int32_t F2Q32(const float value, const float scale);
|
||||
|
||||
// TODO(b/141330728): Move this method elsewhere as part clean up.
|
||||
void PopulateContext(TfLiteTensor* tensors, int tensors_size,
|
||||
ErrorReporter* error_reporter, TfLiteContext* context);
|
||||
|
||||
inline TfLiteTensor CreateFloatTensor(std::initializer_list<float> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable = false) {
|
||||
return CreateFloatTensor(data.begin(), dims, name, is_variable);
|
||||
}
|
||||
TfLiteTensor CreateFloatTensor(std::initializer_list<float> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable = false);
|
||||
|
||||
inline TfLiteTensor CreateBoolTensor(std::initializer_list<bool> data,
|
||||
TfLiteTensor CreateBoolTensor(std::initializer_list<bool> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims,
|
||||
const char* name, float min, float max,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(std::initializer_list<uint8_t> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
float min, float max,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims,
|
||||
const char* name, float min, float max,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(std::initializer_list<int8_t> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
float min, float max,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(float* data, uint8_t* quantized_data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(float* data, int8_t* quantized_data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantizedTensor(float* data, int16_t* quantized_data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantized32Tensor(const int32_t* data, TfLiteIntArray* dims,
|
||||
const char* name, float scale,
|
||||
bool is_variable = false);
|
||||
|
||||
TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
|
||||
TfLiteIntArray* dims, const char* name,
|
||||
bool is_variable = false) {
|
||||
return CreateBoolTensor(data.begin(), dims, name, is_variable);
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantizedTensor(const uint8_t* data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name, float min,
|
||||
float max, bool is_variable = false) {
|
||||
TfLiteTensor result;
|
||||
result.type = kTfLiteUInt8;
|
||||
result.data.uint8 = const_cast<uint8_t*>(data);
|
||||
result.dims = dims;
|
||||
result.params = {ScaleFromMinMax<uint8_t>(min, max),
|
||||
ZeroPointFromMinMax<uint8_t>(min, max)};
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(uint8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = false;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantizedTensor(std::initializer_list<uint8_t> data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name, float min,
|
||||
float max, bool is_variable = false) {
|
||||
return CreateQuantizedTensor(data.begin(), dims, name, min, max, is_variable);
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantizedTensor(const int8_t* data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name, float min,
|
||||
float max, bool is_variable = false) {
|
||||
TfLiteTensor result;
|
||||
result.type = kTfLiteInt8;
|
||||
result.data.int8 = const_cast<int8_t*>(data);
|
||||
result.dims = dims;
|
||||
result.params = {ScaleFromMinMax<int8_t>(min, max),
|
||||
ZeroPointFromMinMax<int8_t>(min, max)};
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantizedTensor(std::initializer_list<int8_t> data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name, float min,
|
||||
float max, bool is_variable = false) {
|
||||
return CreateQuantizedTensor(data.begin(), dims, name, min, max, is_variable);
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantizedTensor(float* data, uint8_t* quantized_data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name,
|
||||
bool is_variable = false) {
|
||||
TfLiteTensor result;
|
||||
SymmetricQuantize(data, dims, quantized_data, &result.params.scale);
|
||||
result.data.uint8 = quantized_data;
|
||||
result.type = kTfLiteUInt8;
|
||||
result.dims = dims;
|
||||
result.params.zero_point = 128;
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(uint8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantizedTensor(float* data, int8_t* quantized_data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name,
|
||||
bool is_variable = false) {
|
||||
TfLiteTensor result;
|
||||
SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
|
||||
result.data.int8 = quantized_data;
|
||||
result.type = kTfLiteInt8;
|
||||
result.dims = dims;
|
||||
result.params.zero_point = 0;
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int8_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantizedTensor(float* data, int16_t* quantized_data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name,
|
||||
bool is_variable = false) {
|
||||
TfLiteTensor result;
|
||||
SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale);
|
||||
result.data.i16 = quantized_data;
|
||||
result.type = kTfLiteInt16;
|
||||
result.dims = dims;
|
||||
result.params.zero_point = 0;
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int16_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name, float scale,
|
||||
bool is_variable = false) {
|
||||
TfLiteTensor result;
|
||||
result.type = kTfLiteInt32;
|
||||
result.data.i32 = const_cast<int32_t*>(data);
|
||||
result.dims = dims;
|
||||
// Quantized int32 tensors always have a zero point of 0, since the range of
|
||||
// int32 values is large, and because zero point costs extra cycles during
|
||||
// processing.
|
||||
result.params = {scale, 0};
|
||||
result.allocation_type = kTfLiteMemNone;
|
||||
result.bytes = ElementCount(*dims) * sizeof(int32_t);
|
||||
result.allocation = nullptr;
|
||||
result.name = name;
|
||||
result.is_variable = is_variable;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
|
||||
TfLiteIntArray* dims,
|
||||
const char* name, float scale,
|
||||
bool is_variable = false) {
|
||||
return CreateQuantized32Tensor(data.begin(), dims, name, scale, is_variable);
|
||||
}
|
||||
float scale, bool is_variable = false);
|
||||
|
||||
template <typename input_type = int32_t,
|
||||
TfLiteType tensor_input_type = kTfLiteInt32>
|
||||
|
@ -28,8 +28,8 @@ LEON_BCC2_MD5 := "cdf78082be4882da2a92c9baa82fe765"
|
||||
TSIM_URL := "https://www.gaisler.com/anonftp/tsim/tsim-eval-2.0.63.tar.gz"
|
||||
TSIM_MD5 := "afa0095d3ed989a949e1467f94e41d2f"
|
||||
|
||||
CMSIS_URL := "https://github.com/ARM-software/CMSIS_5/archive/3d8235079ade1e4df06f91be65e0309cc45e1952.zip"
|
||||
CMSIS_MD5 := "f3e93203e875caf4ba6aff0bccd95d85"
|
||||
CMSIS_URL := "https://github.com/ARM-software/CMSIS_5/archive/8a4db53f69da06e97565fe2f2e8926d193a5759d.zip"
|
||||
CMSIS_MD5 := "e9864fb71b65adc4f7d92a9dea6e1aab"
|
||||
|
||||
AM_SDK_URL := "http://s3.asia.ambiqmicro.com/downloads/AmbiqSuite-Rel2.2.0.zip"
|
||||
AM_SDK_MD5 := "7605fa2d4d97e6bb7a1190c92b66b597"
|
||||
@ -56,8 +56,8 @@ SIFIVE_FE310_LIB_MD5 := "06ee24c4956f8e21670ab3395861fe64"
|
||||
KISSFFT_URL="https://github.com/mborgerding/kissfft/archive/v130.zip"
|
||||
KISSFFT_MD5="438ba1fef5783cc5f5f201395cc477ca"
|
||||
|
||||
RUY_URL="https://github.com/google/ruy/archive/9f53ba413e6fc879236dcaa3e008915973d67a4f.zip"
|
||||
RUY_MD5="ce2c2444cced9dcf6ca6bc908061faa8"
|
||||
RUY_URL="https://github.com/google/ruy/archive/4bdb31ab484e624deef9620ecde2156ca17f6567.zip"
|
||||
RUY_MD5="191d6a173a4fde9742f597f0f4e1f08b"
|
||||
|
||||
CIFAR10_DATASET_URL="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
|
||||
CIFAR10_DATASET_MD5="c32a1d4ab5d03f1284b67883e8d87530"
|
||||
|
@ -33,13 +33,22 @@ TEST(StringUtil, TestStringUtil) {
|
||||
t1->type = kTfLiteString;
|
||||
t1->allocation_type = kTfLiteDynamic;
|
||||
|
||||
char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'X', 'Y', 'Z'};
|
||||
// String tensor with one string of length 3
|
||||
union {
|
||||
char raw_bytes[15];
|
||||
struct {
|
||||
int32_t num_strs;
|
||||
int32_t offsets[2];
|
||||
char str_data[3];
|
||||
} tensor_data;
|
||||
} data;
|
||||
data.tensor_data = {1, {12, 15}, {'X', 'Y', 'Z'}};
|
||||
|
||||
TfLiteQuantization quant;
|
||||
quant.type = kTfLiteNoQuantization;
|
||||
quant.params = nullptr;
|
||||
interpreter.SetTensorParametersReadOnly(2, kTfLiteString, "", {1}, quant,
|
||||
data, 15);
|
||||
interpreter.SetTensorParametersReadOnly(
|
||||
2, kTfLiteString, "", {1}, quant, data.raw_bytes, sizeof(data.raw_bytes));
|
||||
TfLiteTensor* t2 = interpreter.tensor(2);
|
||||
interpreter.AllocateTensors();
|
||||
|
||||
|
@ -37,8 +37,8 @@ EIGEN_URL="$(grep -o 'https.*gitlab.com/libeigen/eigen/-/archive/.*tar\.gz' "${B
|
||||
EIGEN_SHA="$(eval echo $(grep '# SHARED_EIGEN_SHA' "${BZL_FILE_PATH}" | grep -o '\".*\"'))"
|
||||
GEMMLOWP_URL="$(grep -o 'https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
|
||||
GEMMLOWP_SHA="$(eval echo $(grep '# SHARED_GEMMLOWP_SHA' "${BZL_FILE_PATH}" | grep -o '\".*\"'))"
|
||||
RUY_URL="https://github.com/google/ruy/archive/9f53ba413e6fc879236dcaa3e008915973d67a4f.zip"
|
||||
RUY_SHA="fe8345f521bb378745ebdd0f8c5937414849936851d2ec2609774eb2d7098e54"
|
||||
RUY_URL="https://github.com/google/ruy/archive/4bdb31ab484e624deef9620ecde2156ca17f6567.zip"
|
||||
RUY_SHA="51c1492196cdd6fc524dd8b539de5d644bbb436699fab3908585a575e347c789"
|
||||
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
|
||||
GOOGLETEST_SHA="58a6f4277ca2bc8565222b3bbd58a177609e9c488e8a72649359ba51450db7d8"
|
||||
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
|
||||
|
@ -250,7 +250,7 @@ FormatConverter<T>::FormatConverter(const std::vector<int>& shape,
|
||||
for (int i = 0; i < original_rank; i++) {
|
||||
if (block_dim < block_map_.size() && block_map_[block_dim] == i) {
|
||||
int orig_dim = traversal_order_[original_rank + block_dim];
|
||||
block_size_[i] = sparsity.dim_metadata[orig_dim].dense_size;
|
||||
block_size_[block_dim] = sparsity.dim_metadata[orig_dim].dense_size;
|
||||
blocked_shape_[i] = shape[i] / sparsity.dim_metadata[orig_dim].dense_size;
|
||||
block_dim++;
|
||||
} else {
|
||||
@ -273,9 +273,10 @@ void FormatConverter<T>::Populate(const T* src_data, std::vector<int> indices,
|
||||
}
|
||||
|
||||
for (; i < indices.size(); i++) {
|
||||
int orig_dim = block_map_[traversal_order_[i] - orig_rank];
|
||||
const int block_idx = traversal_order_[i] - orig_rank;
|
||||
const int orig_dim = block_map_[block_idx];
|
||||
orig_idx[orig_dim] =
|
||||
orig_idx[orig_dim] * block_size_[orig_dim] + indices[i];
|
||||
orig_idx[orig_dim] * block_size_[block_idx] + indices[i];
|
||||
}
|
||||
|
||||
data_[GetFlattenedIndex(orig_idx, dense_shape_)] = src_data[*src_data_ptr];
|
||||
|
@ -31,18 +31,18 @@ TEST(FormatConverterTest, SimpleTestD0D1) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0 = {3};
|
||||
const std::vector<int> dm1 = {4};
|
||||
EXPECT_EQ(dm0, dim_metadata[0]);
|
||||
EXPECT_EQ(dm1, dim_metadata[2]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 0, 9, 8, 0, 0, 0, 0, 5, 0, 0, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -55,7 +55,7 @@ TEST(FormatConverterTest, SimpleTestS0D1) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0_0 = {0, 2};
|
||||
const std::vector<int> dm0_1 = {0, 2};
|
||||
const std::vector<int> dm1 = {4};
|
||||
@ -63,12 +63,12 @@ TEST(FormatConverterTest, SimpleTestS0D1) {
|
||||
EXPECT_EQ(dm0_1, dim_metadata[1]);
|
||||
EXPECT_EQ(dm1, dim_metadata[2]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 0, 9, 8, 5, 0, 0, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -81,7 +81,7 @@ TEST(FormatConverterTest, SimpleTestD0S1) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0 = {3};
|
||||
const std::vector<int> dm1_0 = {0, 3, 3, 5};
|
||||
const std::vector<int> dm1_1 = {0, 2, 3, 0, 3};
|
||||
@ -89,12 +89,12 @@ TEST(FormatConverterTest, SimpleTestD0S1) {
|
||||
EXPECT_EQ(dm1_0, dim_metadata[2]);
|
||||
EXPECT_EQ(dm1_1, dim_metadata[3]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 9, 8, 5, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ TEST(FormatConverterTest, SimpleTestS0S1) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0_0 = {0, 2};
|
||||
const std::vector<int> dm0_1 = {0, 2};
|
||||
const std::vector<int> dm1_0 = {0, 3, 5};
|
||||
@ -117,12 +117,12 @@ TEST(FormatConverterTest, SimpleTestS0S1) {
|
||||
EXPECT_EQ(dm1_0, dim_metadata[2]);
|
||||
EXPECT_EQ(dm1_1, dim_metadata[3]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 9, 8, 5, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -135,18 +135,18 @@ TEST(FormatConverterTest, SimpleTestD1D0) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0 = {4};
|
||||
const std::vector<int> dm1 = {3};
|
||||
EXPECT_EQ(dm0, dim_metadata[0]);
|
||||
EXPECT_EQ(dm1, dim_metadata[2]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 0, 5, 0, 0, 0, 9, 0, 0, 8, 0, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -159,7 +159,7 @@ TEST(FormatConverterTest, SimpleTestS1D0) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0_0 = {0, 3};
|
||||
const std::vector<int> dm0_1 = {0, 2, 3};
|
||||
const std::vector<int> dm1 = {3};
|
||||
@ -167,12 +167,12 @@ TEST(FormatConverterTest, SimpleTestS1D0) {
|
||||
EXPECT_EQ(dm0_1, dim_metadata[1]);
|
||||
EXPECT_EQ(dm1, dim_metadata[2]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 0, 5, 9, 0, 0, 8, 0, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -185,7 +185,7 @@ TEST(FormatConverterTest, SimpleTestD1S0) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0 = {4};
|
||||
const std::vector<int> dm1_0 = {0, 2, 2, 3, 5};
|
||||
const std::vector<int> dm1_1 = {0, 2, 0, 0, 2};
|
||||
@ -193,12 +193,12 @@ TEST(FormatConverterTest, SimpleTestD1S0) {
|
||||
EXPECT_EQ(dm1_0, dim_metadata[2]);
|
||||
EXPECT_EQ(dm1_1, dim_metadata[3]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 5, 9, 8, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -211,7 +211,7 @@ TEST(FormatConverterTest, SimpleTestS1S0) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0_0 = {0, 3};
|
||||
const std::vector<int> dm0_1 = {0, 2, 3};
|
||||
const std::vector<int> dm1_0 = {0, 2, 3, 5};
|
||||
@ -221,12 +221,12 @@ TEST(FormatConverterTest, SimpleTestS1S0) {
|
||||
EXPECT_EQ(dm1_0, dim_metadata[2]);
|
||||
EXPECT_EQ(dm1_1, dim_metadata[3]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 5, 9, 8, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -239,7 +239,7 @@ TEST(FormatConverterTest, 3DTestS0D1S2) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0_0 = {0, 2};
|
||||
const std::vector<int> dm0_1 = {0, 2};
|
||||
const std::vector<int> dm1 = {2};
|
||||
@ -252,12 +252,12 @@ TEST(FormatConverterTest, 3DTestS0D1S2) {
|
||||
EXPECT_EQ(dm2_0, dim_metadata[4]);
|
||||
EXPECT_EQ(dm2_1, dim_metadata[5]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 9, 8, 5, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -270,7 +270,7 @@ TEST(FormatConverterTest, 3DTestD0D1S2) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0 = {3};
|
||||
const std::vector<int> dm1 = {2};
|
||||
const std::vector<int> dm2_0 = {0, 1, 3, 3, 3, 4, 5};
|
||||
@ -281,12 +281,12 @@ TEST(FormatConverterTest, 3DTestD0D1S2) {
|
||||
EXPECT_EQ(dm2_0, dim_metadata[4]);
|
||||
EXPECT_EQ(dm2_1, dim_metadata[5]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {6, 9, 8, 5, 7};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -300,7 +300,7 @@ TEST(FormatConverterTest, 3DTestS0S1S2) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0_0 = {0, 2};
|
||||
const std::vector<int> dm0_1 = {0, 2};
|
||||
const std::vector<int> dm1_0 = {0, 2, 5};
|
||||
@ -314,12 +314,12 @@ TEST(FormatConverterTest, 3DTestS0S1S2) {
|
||||
EXPECT_EQ(dm2_0, dim_metadata[4]);
|
||||
EXPECT_EQ(dm2_1, dim_metadata[5]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 7, 5, 2, 4, 8, 3, 9};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -333,7 +333,7 @@ TEST(FormatConverterTest, 3DTestS0S2S1) {
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0_0 = {0, 2};
|
||||
const std::vector<int> dm0_1 = {0, 2};
|
||||
const std::vector<int> dm1_0 = {0, 2, 5};
|
||||
@ -347,12 +347,12 @@ TEST(FormatConverterTest, 3DTestS0S2S1) {
|
||||
EXPECT_EQ(dm2_0, dim_metadata[4]);
|
||||
EXPECT_EQ(dm2_1, dim_metadata[5]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 7, 5, 2, 4, 8, 3, 9};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -369,25 +369,58 @@ TEST(FormatConverterTest, BlockTestD0D1) {
|
||||
block_size, block_map);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm = {2};
|
||||
EXPECT_EQ(dm, dim_metadata[0]);
|
||||
EXPECT_EQ(dm, dim_metadata[2]);
|
||||
EXPECT_EQ(dm, dim_metadata[4]);
|
||||
EXPECT_EQ(dm, dim_metadata[6]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 0, 0, 4, 2, 3, 0, 0,
|
||||
0, 0, 0, 0, 5, 0, 0, 6};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
// BCSR
|
||||
TEST(FormatConverterTest, BlockTestD0S1) {
|
||||
TEST(FormatConverterTest, BlockTestD0S11DBlock) {
|
||||
const std::vector<int> dense_values = {1, 0, 2, 3, 0, 4, 0, 0,
|
||||
0, 0, 5, 0, 0, 0, 0, 6};
|
||||
const std::vector<int> dense_shape = {4, 4};
|
||||
const std::vector<int> traversal_order = {0, 1, 2};
|
||||
const std::vector<TfLiteDimensionType> format = {kTfLiteDimDense,
|
||||
kTfLiteDimSparseCSR};
|
||||
const std::vector<int> block_size = {2};
|
||||
const std::vector<int> block_map = {1};
|
||||
FormatConverter<int> converter(dense_shape, traversal_order, format,
|
||||
block_size, block_map);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm0 = {4};
|
||||
const std::vector<int> dm2 = {2};
|
||||
const std::vector<int> dm1_0 = {0, 2, 3, 4, 5};
|
||||
const std::vector<int> dm1_1 = {0, 1, 0, 1, 1};
|
||||
EXPECT_EQ(dm0, dim_metadata[0]);
|
||||
EXPECT_EQ(dm1_0, dim_metadata[2]);
|
||||
EXPECT_EQ(dm1_1, dim_metadata[3]);
|
||||
EXPECT_EQ(dm2, dim_metadata[4]);
|
||||
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 0, 2, 3, 0, 4, 5, 0, 0, 6};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
// BCSR
|
||||
TEST(FormatConverterTest, BlockTestD0S12DBlock) {
|
||||
const std::vector<int> dense_values = {1, 0, 2, 3, 0, 4, 0, 0,
|
||||
0, 0, 5, 0, 0, 0, 0, 6};
|
||||
const std::vector<int> dense_shape = {4, 4};
|
||||
@ -400,7 +433,7 @@ TEST(FormatConverterTest, BlockTestD0S1) {
|
||||
block_size, block_map);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm = {2};
|
||||
const std::vector<int> dm1_0 = {0, 2, 3};
|
||||
const std::vector<int> dm1_1 = {0, 1, 1};
|
||||
@ -410,12 +443,12 @@ TEST(FormatConverterTest, BlockTestD0S1) {
|
||||
EXPECT_EQ(dm, dim_metadata[4]);
|
||||
EXPECT_EQ(dm, dim_metadata[6]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -433,7 +466,7 @@ TEST(FormatConverterTest, BlockTestD1S0) {
|
||||
block_size, block_map);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm = {2};
|
||||
const std::vector<int> dm1_0 = {0, 1, 3};
|
||||
const std::vector<int> dm1_1 = {0, 0, 1};
|
||||
@ -443,12 +476,12 @@ TEST(FormatConverterTest, BlockTestD1S0) {
|
||||
EXPECT_EQ(dm, dim_metadata[4]);
|
||||
EXPECT_EQ(dm, dim_metadata[6]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 0, 0, 4, 2, 0, 3, 0, 5, 0, 0, 6};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -466,7 +499,7 @@ TEST(FormatConverterTest, BlockTestD0S1LastBlockEmpty) {
|
||||
block_size, block_map);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm = {2};
|
||||
const std::vector<int> dm1_0 = {0, 2, 2};
|
||||
const std::vector<int> dm1_1 = {0, 1};
|
||||
@ -476,12 +509,12 @@ TEST(FormatConverterTest, BlockTestD0S1LastBlockEmpty) {
|
||||
EXPECT_EQ(dm, dim_metadata[4]);
|
||||
EXPECT_EQ(dm, dim_metadata[6]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 0, 0, 4, 2, 3, 0, 0};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
|
||||
@ -499,7 +532,7 @@ TEST(FormatConverterTest, BlockTestD0S1ColMajorBlock) {
|
||||
block_size, block_map);
|
||||
converter.DenseToSparse(dense_values.data());
|
||||
|
||||
const auto& dim_metadata = converter.GetDimMetadata();
|
||||
const auto dim_metadata = converter.GetDimMetadata();
|
||||
const std::vector<int> dm = {2};
|
||||
const std::vector<int> dm1_0 = {0, 3, 4};
|
||||
const std::vector<int> dm1_1 = {0, 1, 2, 1};
|
||||
@ -509,13 +542,13 @@ TEST(FormatConverterTest, BlockTestD0S1ColMajorBlock) {
|
||||
EXPECT_EQ(dm, dim_metadata[4]);
|
||||
EXPECT_EQ(dm, dim_metadata[6]);
|
||||
|
||||
const auto& data = converter.GetData();
|
||||
const auto data = converter.GetData();
|
||||
const std::vector<int> expected_data = {1, 1, 0, 0, 2, 2, 3, 3,
|
||||
0, 0, 4, 4, 5, 0, 0, 0};
|
||||
EXPECT_EQ(expected_data, data);
|
||||
|
||||
converter.SparseToDense(expected_data.data());
|
||||
const auto& data_back = converter.GetData();
|
||||
const auto data_back = converter.GetData();
|
||||
EXPECT_EQ(data_back, dense_values);
|
||||
}
|
||||
} // namespace
|
||||
|
@ -461,6 +461,9 @@ class ControlFlowTransformer(converter.Base):
|
||||
loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
|
||||
|
||||
opts = self._create_loop_options(node)
|
||||
opts.keys.append(gast.Constant('iterate_names', kind=None))
|
||||
opts.values.append(gast.Constant(
|
||||
parser.unparse(node.target, include_encoding_marker=False), kind=None))
|
||||
|
||||
if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
|
||||
extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
|
||||
|
@ -0,0 +1,862 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "-vLwpT31YOJk"
|
||||
},
|
||||
"source": [
|
||||
"TODO(b/138297412): This colab retains some useful code snippets and demonstrations that used to be in the tf.function/AutoGraph customization tutorial, and should be rolled into the existing docs as part of a broader markdown-\u003ecolab conversion."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "otIdN1TS8N7S"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "I0xDjO4SHLUD"
|
||||
},
|
||||
"source": [
|
||||
"Define a helper function to demonstrate the kinds of errors you might encounter:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "D25apou9IOXa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import traceback\n",
|
||||
"import contextlib\n",
|
||||
"\n",
|
||||
"# Some helper code to demonstrate the kinds of errors you might encounter.\n",
|
||||
"@contextlib.contextmanager\n",
|
||||
"def assert_raises(error_class):\n",
|
||||
" try:\n",
|
||||
" yield\n",
|
||||
" except error_class as e:\n",
|
||||
" print('Caught expected exception \\n {}:'.format(error_class))\n",
|
||||
" traceback.print_exc(limit=2)\n",
|
||||
" except Exception as e:\n",
|
||||
" raise e\n",
|
||||
" else:\n",
|
||||
" raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
|
||||
" error_class))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "5f05Vr_YBUCz"
|
||||
},
|
||||
"source": [
|
||||
"## Using AutoGraph\n",
|
||||
"\n",
|
||||
"The [autograph](https://www.tensorflow.org/guide/function) library is fully integrated with `tf.function`, and it will rewrite conditionals and loops which depend on Tensors to run dynamically in the graph.\n",
|
||||
"\n",
|
||||
"`tf.cond` and `tf.while_loop` continue to work with `tf.function`, but code with control flow is often easier to write and understand when written in imperative style."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "xgKmkrNTZSyz"
|
||||
},
|
||||
"source": [
|
||||
"## AutoGraph: Conditionals\n",
|
||||
"\n",
|
||||
"AutoGraph will convert `if` statements into the equivalent `tf.cond` calls.\n",
|
||||
"\n",
|
||||
"This substitution is made if the condition is a Tensor. Otherwise, the conditional is executed during tracing."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "20WlM9T2I9EV"
|
||||
},
|
||||
"source": [
|
||||
"Here is a function that checks if the resulting graph uses `tf.cond`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "E-7KllizZYsy"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test_tf_cond(f, *args):\n",
|
||||
" g = f.get_concrete_function(*args).graph\n",
|
||||
" if any(node.name == 'cond' for node in g.as_graph_def().node):\n",
|
||||
" print(\"{}({}) uses tf.cond.\".format(\n",
|
||||
" f.__name__, ', '.join(map(str, args))))\n",
|
||||
" else:\n",
|
||||
" print(\"{}({}) executes normally.\".format(\n",
|
||||
" f.__name__, ', '.join(map(str, args))))\n",
|
||||
"\n",
|
||||
" print(\" result: \",f(*args).numpy())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "DlqiutEEJHOe"
|
||||
},
|
||||
"source": [
|
||||
"This substitution is made if the condition is a Tensor. Otherwise, the conditional is executed during tracing.\n",
|
||||
"\n",
|
||||
"Passing a python `True` executes the conditional normally:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "fCMywOXwJLIQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def dropout(x, training=True):\n",
|
||||
" if training:\n",
|
||||
" x = tf.nn.dropout(x, rate=0.5)\n",
|
||||
" return x"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "68D2RZ17JM8u"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "WEz0QYucJPBa"
|
||||
},
|
||||
"source": [
|
||||
"But passing a tensor replaces the python `if` with a `tf.cond`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "o86paGR-Zadi"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "5xFLfdApZh8q"
|
||||
},
|
||||
"source": [
|
||||
"`tf.cond` has a number of subtleties.\n",
|
||||
"\n",
|
||||
"it works by tracing both sides of the conditional, and then choosing the appropriate branch at runtime, depending on the condition. Tracing both sides can result in unexpected execution of Python code."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "VTMoZEVaZiwk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def f(x):\n",
|
||||
" if x \u003e 0:\n",
|
||||
" x = x + 1.\n",
|
||||
" print(\"Tracing `then` branch\")\n",
|
||||
" else:\n",
|
||||
" x = x - 1.\n",
|
||||
" print(\"Tracing `else` branch\")\n",
|
||||
" return x"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "HqBVIZWb0Qzn"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"f(-1.0).numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "BIMfbXlW0QdP"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"f(1.0).numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "2nBnJ42v0Pvq"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"f(tf.constant(1.0)).numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "zyzzvtN5Jfpb"
|
||||
},
|
||||
"source": [
|
||||
"It requires that if one branch creates a tensor used downstream, the other branch must also create that tensor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "k_dxWHeFZlaQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def f():\n",
|
||||
" if tf.constant(True):\n",
|
||||
" x = tf.ones([3, 3])\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"# Throws an error because both branches need to define `x`.\n",
|
||||
"with assert_raises(ValueError):\n",
|
||||
" f()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "wP-LZP6cztnu"
|
||||
},
|
||||
"source": [
|
||||
"If you want to be sure that a particular section of control flow is never converted by autograph, then explicitly convert the object to a python type so an error is raised instead: "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "iG_VDavjzrzV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def f(x, y):\n",
|
||||
" if bool(x):\n",
|
||||
" y = y + 1.\n",
|
||||
" print(\"Tracing `then` branch\")\n",
|
||||
" else:\n",
|
||||
" y = y - 1.\n",
|
||||
" print(\"Tracing `else` branch\")\n",
|
||||
" return y"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "kQ4CRP9T0rH2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"f(True, 0).numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "ww9tCzHy0rkv"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"f(False, 0).numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "ppuV7iug0r7i"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with assert_raises(TypeError):\n",
|
||||
" f(tf.constant(True), 0.0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "yho4J0a0ZkQS"
|
||||
},
|
||||
"source": [
|
||||
"## AutoGraph and loops\n",
|
||||
"\n",
|
||||
"AutoGraph has a few simple rules for converting loops.\n",
|
||||
"\n",
|
||||
"- `for`: Convert if the iterable is a tensor\n",
|
||||
"- `while`: Convert if the while condition depends on a tensor\n",
|
||||
"\n",
|
||||
"If a loop is converted, it will be dynamically unrolled with `tf.while_loop`, or in the special case of a `for x in tf.data.Dataset`, transformed into `tf.data.Dataset.reduce`.\n",
|
||||
"\n",
|
||||
"If a loop is _not_ converted, it will be statically unrolled "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "OyzGNQAuZsky"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def test_dynamically_unrolled(f, *args):\n",
|
||||
" g = f.get_concrete_function(*args).graph\n",
|
||||
" if any(node.name == 'while' for node in g.as_graph_def().node):\n",
|
||||
" print(\"{}({}) uses tf.while_loop.\".format(\n",
|
||||
" f.__name__, ', '.join(map(str, args))))\n",
|
||||
" elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):\n",
|
||||
" print(\"{}({}) uses tf.data.Dataset.reduce.\".format(\n",
|
||||
" f.__name__, ', '.join(map(str, args))))\n",
|
||||
" else:\n",
|
||||
" print(\"{}({}) gets unrolled.\".format(\n",
|
||||
" f.__name__, ', '.join(map(str, args))))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "KFO1BSN9JkRP"
|
||||
},
|
||||
"source": [
|
||||
"### For loops\n",
|
||||
"\n",
|
||||
"Here is a `tf.function` that demonstrates static unrolling:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "frecgTco_00V"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def for_in_range():\n",
|
||||
" x = 0\n",
|
||||
" for i in range(5):\n",
|
||||
" x += i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(for_in_range)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "PMdl0azc_5d4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def for_in_tfrange():\n",
|
||||
" x = tf.constant(0, dtype=tf.int32)\n",
|
||||
" for i in tf.range(5):\n",
|
||||
" x += i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(for_in_tfrange)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Q7tmncQTZt6_"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def for_in_tfdataset():\n",
|
||||
" x = tf.constant(0, dtype=tf.int64)\n",
|
||||
" for i in tf.data.Dataset.range(5):\n",
|
||||
" x += i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(for_in_tfdataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "eyPzDYiJAC8f"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def while_py_cond():\n",
|
||||
" x = 5\n",
|
||||
" while x \u003e 0:\n",
|
||||
" x -= 1\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(while_py_cond)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "l6s7aU-padY5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def while_tf_cond():\n",
|
||||
" x = tf.constant(5)\n",
|
||||
" while x \u003e 0:\n",
|
||||
" x -= 1\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(while_tf_cond)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "dSr64Xn6ap-S"
|
||||
},
|
||||
"source": [
|
||||
" If you have a `break` or early `return` clause that depends on a tensor, the top-level condition or iterable should also be a tensor.\n",
|
||||
"\n",
|
||||
"Compare the following examples:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "hG2Fe_OEAwpY"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def while_py_true_py_break(x):\n",
|
||||
" while True: # py true\n",
|
||||
" if x == 0: # py break\n",
|
||||
" break\n",
|
||||
" x -= 1\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(while_py_true_py_break, 5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Sr2cn5bY_E_9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def buggy_while_py_true_tf_break(x):\n",
|
||||
" while True: # py true\n",
|
||||
" if tf.equal(x, 0): # tf break\n",
|
||||
" break\n",
|
||||
" x -= 1\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"with assert_raises(TypeError):\n",
|
||||
" test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Q-VirD-5avdZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def while_tf_true_tf_break(x):\n",
|
||||
" while tf.constant(True): # tf true\n",
|
||||
" if x == 0: # py break\n",
|
||||
" break\n",
|
||||
" x -= 1\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(while_tf_true_tf_break, 5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Upx5J0j8_Ldu"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def buggy_py_for_tf_break():\n",
|
||||
" x = 0\n",
|
||||
" for i in range(5): # py for\n",
|
||||
" if tf.equal(i, 3): # tf break\n",
|
||||
" break\n",
|
||||
" x += i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"with assert_raises(TypeError):\n",
|
||||
" test_dynamically_unrolled(buggy_py_for_tf_break)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "GQHbodav_QMt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def tf_for_py_break():\n",
|
||||
" x = 0\n",
|
||||
" for i in tf.range(5): # tf for\n",
|
||||
" if i == 3: # py break\n",
|
||||
" break\n",
|
||||
" x += i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"test_dynamically_unrolled(tf_for_py_break)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "hyksHW9TCukR"
|
||||
},
|
||||
"source": [
|
||||
"In order to accumulate results from a dynamically unrolled loop, you'll want to use `tf.TensorArray`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "HJ3Vb3dXfefN"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_size = 2\n",
|
||||
"seq_len = 3\n",
|
||||
"feature_size = 4\n",
|
||||
"\n",
|
||||
"def rnn_step(inp, state):\n",
|
||||
" return inp + state\n",
|
||||
"\n",
|
||||
"@tf.function\n",
|
||||
"def dynamic_rnn(rnn_step, input_data, initial_state):\n",
|
||||
" # [batch, time, features] -\u003e [time, batch, features]\n",
|
||||
" input_data = tf.transpose(input_data, [1, 0, 2])\n",
|
||||
" max_seq_len = input_data.shape[0]\n",
|
||||
"\n",
|
||||
" states = tf.TensorArray(tf.float32, size=max_seq_len)\n",
|
||||
" state = initial_state\n",
|
||||
" for i in tf.range(max_seq_len):\n",
|
||||
" state = rnn_step(input_data[i], state)\n",
|
||||
" states = states.write(i, state)\n",
|
||||
" return tf.transpose(states.stack(), [1, 0, 2])\n",
|
||||
" \n",
|
||||
"dynamic_rnn(rnn_step,\n",
|
||||
" tf.random.uniform([batch_size, seq_len, feature_size]),\n",
|
||||
" tf.zeros([batch_size, feature_size]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "9gmLpHY-bkly"
|
||||
},
|
||||
"source": [
|
||||
"### Gotcha's\n",
|
||||
"\n",
|
||||
"As with `tf.cond`, `tf.while_loop` also comes with a number of subtleties.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "FJdfznhhKO7D"
|
||||
},
|
||||
"source": [
|
||||
"#### Zero iterations\n",
|
||||
"\n",
|
||||
"Since a loop can execute 0 times, all tensors used downstream of the while_loop must be initialized above the loop.\n",
|
||||
"\n",
|
||||
"Here is an example of incorrect code:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "CocT5RHwblrQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def buggy_loop_var_uninitialized():\n",
|
||||
" for i in tf.range(3):\n",
|
||||
" x = i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"with assert_raises(ValueError):\n",
|
||||
" buggy_loop_var_uninitialized()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "ncr7tRZ1KWh9"
|
||||
},
|
||||
"source": [
|
||||
"And the correct version:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Wm7wIKXcCDGf"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def f():\n",
|
||||
" x = tf.constant(0)\n",
|
||||
" for i in tf.range(3):\n",
|
||||
" x = i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"f()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "CM7qXVY0KZHB"
|
||||
},
|
||||
"source": [
|
||||
"#### Consistent shapes and types\n",
|
||||
"\n",
|
||||
"The shape/dtypes of all loop variables must stay consistent with each iteration.\n",
|
||||
"\n",
|
||||
"Here is an incorrect example that attempts to change a tensor's type:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "FSftc9cCbpAo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def buggy_loop_type_changes():\n",
|
||||
" x = tf.constant(0, dtype=tf.float32)\n",
|
||||
" for i in tf.range(3): # Yields tensors of type tf.int32...\n",
|
||||
" x = i\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"with assert_raises(TypeError):\n",
|
||||
" buggy_loop_type_changes()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "M5l90NAHKsUM"
|
||||
},
|
||||
"source": [
|
||||
"Here is an incorrect example that attempts to change a Tensor's shape while iterating:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "kWF189prbuK0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def buggy_concat():\n",
|
||||
" x = tf.ones([0, 10])\n",
|
||||
" for i in tf.range(5):\n",
|
||||
" x = tf.concat([x, tf.ones([1, 10])], axis=0)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"with assert_raises(ValueError):\n",
|
||||
" buggy_concat()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "miYnYcznCHeV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def concat_with_padding():\n",
|
||||
" x = tf.zeros([5, 10])\n",
|
||||
" for i in tf.range(5):\n",
|
||||
" x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)\n",
|
||||
" x.set_shape([5, 10])\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"concat_with_padding()\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "performance.ipynb",
|
||||
"private_outputs": true,
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
@ -420,6 +420,21 @@ def extra_test(break_):
|
||||
break_, = ag__.for_stmt(range(10), extra_test, ..., (break_,))
|
||||
```
|
||||
|
||||
Mixing Tensor-dependent `break` and Python-dependent loops is disallowed:
|
||||
|
||||
```
|
||||
@tf.function
|
||||
def buggy_while_py_true_tf_break(x):
|
||||
while True: # python conditional
|
||||
if tf.equal(x, 0): # tensor break
|
||||
break
|
||||
x -= 1
|
||||
return x
|
||||
|
||||
# Raises OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed
|
||||
# buggy_while_true_tf_break(5)
|
||||
```
|
||||
|
||||
### `continue` statements
|
||||
|
||||
Code blocks in which `continue` statements are used are rewritten with
|
||||
|
@ -62,7 +62,7 @@ Adding a call to `tf.config.experimental_execute_functions_eagerly` before
|
||||
executing the function will land the debugger in the original code instead:
|
||||
|
||||
```
|
||||
tf.config.experimental_run_functions_eagerly(True)
|
||||
tf.config.run_functions_eagerly(True)
|
||||
f(1)
|
||||
```
|
||||
|
||||
|
@ -501,8 +501,18 @@ def _tf_range_for_stmt(
|
||||
|
||||
iterate = compat_util.BasicRef(start)
|
||||
|
||||
def _value_or(name, var, default):
|
||||
if (name == opts['iterate_names']
|
||||
and isinstance(var, special_values.Undefined)):
|
||||
return default
|
||||
return var
|
||||
|
||||
def aug_get_state():
|
||||
return (iterate.value,) + get_state()
|
||||
state_vars = get_state()
|
||||
state_vars = tuple(
|
||||
_value_or(name, var, iterate.value)
|
||||
for name, var in zip(symbol_names, state_vars))
|
||||
return (iterate.value,) + state_vars
|
||||
|
||||
def aug_set_state(aug_loop_vars):
|
||||
# TOOD(mdan): Use starred assignment once we can switch to Py3-only syntax.
|
||||
@ -876,16 +886,19 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts):
|
||||
init_vars, loop_vars, new_loop_vars, symbol_names, opts)
|
||||
return new_loop_vars
|
||||
|
||||
# Non-v2 while_loop unpacks the results when there is only one return value.
|
||||
# This enforces consistency across versions.
|
||||
opts['return_same_structure'] = True
|
||||
|
||||
if 'shape_invariants' in opts:
|
||||
opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
|
||||
opts['shape_invariants'], init_vars)
|
||||
|
||||
while_loop_opts = dict(opts)
|
||||
while_loop_opts.pop('iterate_names', None)
|
||||
|
||||
# Non-v2 while_loop unpacks the results when there is only one return value.
|
||||
# This enforces consistency across versions.
|
||||
while_loop_opts['return_same_structure'] = True
|
||||
|
||||
final_loop_vars = control_flow_ops.while_loop(
|
||||
aug_test, aug_body, init_vars, **opts)
|
||||
aug_test, aug_body, init_vars, **while_loop_opts)
|
||||
set_state(final_loop_vars)
|
||||
|
||||
|
||||
|
@ -89,7 +89,7 @@ class ForLoopTest(test.TestCase):
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_range_tensor_explicit_limit_delta(self):
|
||||
@ -109,7 +109,7 @@ class ForLoopTest(test.TestCase):
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (-171207,))
|
||||
|
||||
def test_range_tensor_explicit_limit_negative_delta(self):
|
||||
@ -129,7 +129,7 @@ class ForLoopTest(test.TestCase):
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
|
||||
def test_range_tensor_random_delta(self):
|
||||
@ -150,7 +150,7 @@ class ForLoopTest(test.TestCase):
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (1234,))
|
||||
|
||||
def test_range_tensor_random_negative_delta(self):
|
||||
@ -171,7 +171,7 @@ class ForLoopTest(test.TestCase):
|
||||
get_state=lambda: (s,),
|
||||
set_state=set_state,
|
||||
symbol_names=('s',),
|
||||
opts={})
|
||||
opts={'iterate_names': 'i'})
|
||||
self.assertEqual(self.evaluate(s), (171207,))
|
||||
|
||||
def test_tensor_with_extra_test_object_vars(self):
|
||||
|
@ -31,7 +31,7 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
# This value changes every day with an automatic CL. It can be modified in code
|
||||
# via `forward_compatibility_horizon()` or with the environment variable
|
||||
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
|
||||
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 4, 29)
|
||||
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 4, 30)
|
||||
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
|
||||
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
|
||||
|
||||
|
@ -19,11 +19,9 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops import testing
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import test
|
||||
@ -34,11 +32,7 @@ class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase,
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testShuffleAndRepeatFusion(self):
|
||||
if tf2.enabled() and context.executing_eagerly():
|
||||
expected = "Shuffle"
|
||||
else:
|
||||
expected = "ShuffleAndRepeat"
|
||||
|
||||
expected = "ShuffleAndRepeat"
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
testing.assert_next([expected])).shuffle(10).repeat(2)
|
||||
options = dataset_ops.Options()
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user