Add a debug-only pass that introduces a small error to a designated TF node.

This lets us check how susceptible a model or a unit test is to floating point
differences.

PiperOrigin-RevId: 240824222
commit ef5519ab43 (parent 36611db67d)
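The pass is driven entirely by TF_XLA_FLAGS, using the two flags defined in flags.cc below. A hypothetical invocation (the tensor names and script name here are placeholders, not part of this change) might look like:

    TF_XLA_FLAGS='--tf_introduce_floating_point_jitter_to_tensors=matmul:0,conv:1 --tf_introduce_floating_point_jitter_amount=0.001' python model.py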
tensorflow/compiler/jit/BUILD
@@ -237,6 +237,7 @@ cc_library(
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)

@@ -497,6 +498,7 @@ cc_library(
        "encapsulate_xla_computations_pass.cc",
        "extract_outside_compilation_pass.cc",
        "increase_dynamism_for_auto_jit_pass.cc",
        "introduce_floating_point_jitter_pass.cc",
        "mark_for_compilation_pass.cc",
        "mark_for_compilation_pass_test_helper.cc",
        "partially_decluster_pass.cc",
@@ -509,6 +511,7 @@ cc_library(
        "encapsulate_xla_computations_pass.h",
        "extract_outside_compilation_pass.h",
        "increase_dynamism_for_auto_jit_pass.h",
        "introduce_floating_point_jitter_pass.h",
        "mark_for_compilation_pass.h",
        "mark_for_compilation_pass_test_helper.h",
        "partially_decluster_pass.h",
@@ -523,6 +526,7 @@ cc_library(
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/cc:scope_internal",
        "//tensorflow/compiler/jit/graphcycles",
        "//tensorflow/compiler/jit/ops:xla_ops",
@@ -551,6 +555,7 @@ cc_library(
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

@@ -636,6 +641,8 @@ tf_cc_test(
        "encapsulate_xla_computations_pass_test.cc",
        "extract_outside_compilation_pass_test.cc",
        "increase_dynamism_for_auto_jit_pass_test.cc",
        "introduce_floating_point_jitter_pass_internal.h",
        "introduce_floating_point_jitter_pass_test.cc",
        "mark_for_compilation_pass_test.cc",
        "partially_decluster_pass_test.cc",
    ],
@@ -677,6 +684,7 @@ tf_cc_test(
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)
tensorflow/compiler/jit/flags.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include <mutex>  // NOLINT

#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"

@@ -26,6 +27,7 @@ BuildXlaOpsPassFlags* build_ops_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;

std::vector<Flag>* flag_list;
std::once_flag flags_init;
@@ -86,21 +88,38 @@ void AllocateAndParseFlags() {
   ops_flags = new XlaOpsCommonFlags;
   ops_flags->tf_xla_always_defer_compilation = false;
 
-  flag_list = new std::vector<Flag>({
-      Flag("tf_xla_enable_lazy_compilation",
-           &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
-      Flag("tf_xla_print_cluster_outputs",
-           &build_ops_flags->tf_xla_print_cluster_outputs,
-           "If true then insert Print nodes to print out values produced by "
-           "XLA clusters."),
-
-      Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
-           "Switch a device into 'on-demand' mode, where instead of "
-           "autoclustering ops are compiled one by one just-in-time."),
-
-      Flag("tf_xla_always_defer_compilation",
-           &ops_flags->tf_xla_always_defer_compilation, ""),
-  });
+  jitter_flags = new IntroduceFloatingPointJitterPassFlags;
+  jitter_flags->jitter_amount = 1e-5;
+
+  auto setter_for_jitter_tensor_names = [](string sequence) {
+    jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
+    return true;
+  };
+
+  flag_list = new std::vector<Flag>(
+      {Flag("tf_xla_enable_lazy_compilation",
+            &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
+       Flag("tf_xla_print_cluster_outputs",
+            &build_ops_flags->tf_xla_print_cluster_outputs,
+            "If true then insert Print nodes to print out values produced by "
+            "XLA clusters."),
+
+       Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
+            "Switch a device into 'on-demand' mode, where instead of "
+            "autoclustering ops are compiled one by one just-in-time."),
+
+       Flag("tf_xla_always_defer_compilation",
+            &ops_flags->tf_xla_always_defer_compilation, ""),
+
+       Flag("tf_introduce_floating_point_jitter_to_tensors",
+            setter_for_jitter_tensor_names, "",
+            "The Tensors to add the jitter to. The tensors are named in the "
+            "TensorId format of <node name>:<output idx>."),
+       Flag("tf_introduce_floating_point_jitter_amount",
+            &jitter_flags->jitter_amount,
+            "The amount of jitter to introduce. This amount is added to each "
+            "element in the tensors named in `tensor_names`.")});
 
   AppendMarkForCompilationPassFlagsInternal(flag_list);
   xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
 }
@@ -127,9 +146,14 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
  return *ops_flags;
}

const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return *jitter_flags;
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

}  // namespace tensorflow
tensorflow/compiler/jit/flags.h
@@ -82,6 +82,17 @@ struct BuildXlaOpsPassFlags {
  bool tf_xla_print_cluster_outputs;
};

// Flags for the IntroduceFloatingPointJitter pass.
struct IntroduceFloatingPointJitterPassFlags {
  // The amount of jitter to introduce. This amount is added to each element in
  // the tensors named in `tensor_names`.
  float jitter_amount;

  // The Tensors to add the jitter to. The tensors are named in the TensorId
  // format of <node name>:<output idx>.
  std::vector<string> tensor_names;
};

// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.

@@ -94,6 +105,9 @@ const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
XlaDeviceFlags* GetXlaDeviceFlags();
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();

const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();

// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//
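As an illustration of the struct above (hypothetical flag values; the default mirrors AllocateAndParseFlags in flags.cc), --tf_introduce_floating_point_jitter_to_tensors=matmul:0,conv:1 would leave the parsed flags in this state:

    IntroduceFloatingPointJitterPassFlags flags;
    flags.jitter_amount = 1e-5;                   // default unless tf_introduce_floating_point_jitter_amount is set
    flags.tensor_names = {"matmul:0", "conv:1"};  // TensorId format: <node name>:<output idx>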
tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc (new file, 153 lines)
@@ -0,0 +1,153 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/graph/tensor_id.h"

namespace tensorflow {
namespace {
// Maps the requested tensor names to the nodes producing them, returning each
// node together with the deduplicated output indices to perturb.  The result
// is sorted by node id so the graph transformation is deterministic.
std::vector<std::pair<Node*, std::vector<int>>> GetNodesToModify(
    const Graph& g, absl::Span<const string> tensor_names) {
  absl::flat_hash_map<string, Node*> name_to_node;
  for (Node* n : g.op_nodes()) {
    name_to_node[n->name()] = n;
  }

  absl::flat_hash_map<Node*, std::vector<int>> nodes_to_modify_map;

  for (const string& tensor_name : tensor_names) {
    TensorId tensor_id = ParseTensorName(tensor_name);
    auto it = name_to_node.find(tensor_id.node());
    DCHECK(it != name_to_node.end());
    nodes_to_modify_map[it->second].push_back(tensor_id.index());
  }

  std::vector<std::pair<Node*, std::vector<int>>> nodes_to_modify;
  absl::c_copy(nodes_to_modify_map, std::back_inserter(nodes_to_modify));

  absl::c_sort(nodes_to_modify,
               [](const std::pair<Node*, std::vector<int>>& a,
                  const std::pair<Node*, std::vector<int>>& b) {
                 return a.first->id() < b.first->id();
               });

  for (auto& p : nodes_to_modify) {
    absl::c_sort(p.second);
    p.second.erase(std::unique(p.second.begin(), p.second.end()),
                   p.second.end());
  }

  return nodes_to_modify;
}

// Rewrites all users of output `oidx` of node `n` to instead consume
// `n:oidx` + `jitter_amount`.  Jitter constants are cached per (dtype, node)
// in `node_to_jitter_constant` so they can be reused across output indices.
Status IntroduceJitterToTensor(
    Graph* g, Node* n, int oidx, float jitter_amount,
    absl::flat_hash_map<std::pair<DataType, Node*>, Output>*
        node_to_jitter_constant) {
  std::vector<const Edge*> edges_to_update;
  absl::c_copy_if(n->out_edges(), std::back_inserter(edges_to_update),
                  [&](const Edge* e) { return e->src_output() == oidx; });

  if (edges_to_update.empty()) {
    VLOG(1) << "No users for " << TensorId(n->name(), oidx).ToString();
    return Status::OK();
  }

  VLOG(1) << "Updating " << edges_to_update.size() << " users for "
          << TensorId(n->name(), oidx).ToString();

  Status status;
  Scope s = NewInternalScope(g, &status, /*refiner=*/nullptr)
                .NewSubScope(absl::StrCat(n->name(), "/jitter"));

  Output node_out(n, oidx);
  Output jitter_constant;
  DataType dtype = n->output_type(oidx);
  auto it = node_to_jitter_constant->find({dtype, n});
  if (it == node_to_jitter_constant->end()) {
    Tensor constant_tensor;
    if (dtype == DT_FLOAT) {
      constant_tensor = Tensor(static_cast<float>(jitter_amount));
    } else if (dtype == DT_HALF) {
      constant_tensor = Tensor(Eigen::half(jitter_amount));
    } else {
      return errors::Unimplemented("Only float and half are supported");
    }

    jitter_constant =
        ops::Const(s.WithOpName("jitter_amount"), constant_tensor);
    (*node_to_jitter_constant)[{dtype, n}] = jitter_constant;
  } else {
    jitter_constant = it->second;
  }

  Output jittered_output =
      ops::Add(s.NewSubScope(absl::StrCat(oidx)).WithOpName("jittered_output"),
               jitter_constant, node_out);

  TF_RETURN_IF_ERROR(status);

  for (const Edge* e : edges_to_update) {
    VLOG(3) << "Updating " << e->dst()->name();
    TF_RETURN_IF_ERROR(
        g->UpdateEdge(jittered_output.node(), 0, e->dst(), e->dst_input()));
  }

  // Add a control edge to make sure that the two inputs to jittered_output are
  // from the same frame.
  g->AddControlEdge(n, jitter_constant.node());

  return Status::OK();
}
}  // namespace

Status IntroduceFloatingPointJitter(Graph* graph,
                                    absl::Span<string const> tensor_names,
                                    float jitter_amount) {
  if (tensor_names.empty()) {
    VLOG(3) << "Nothing to do";
    return Status::OK();
  }

  std::vector<std::pair<Node*, std::vector<int>>> nodes_to_modify =
      GetNodesToModify(*graph, tensor_names);

  absl::flat_hash_map<std::pair<DataType, Node*>, Output>
      node_to_jitter_constant;
  for (const auto& p : nodes_to_modify) {
    for (int oidx : p.second) {
      TF_RETURN_IF_ERROR(IntroduceJitterToTensor(
          graph, p.first, oidx, jitter_amount, &node_to_jitter_constant));
    }
  }

  return Status::OK();
}

Status IntroduceFloatingPointJitterPass::Run(
    const GraphOptimizationPassOptions& options) {
  const IntroduceFloatingPointJitterPassFlags& flags =
      GetIntroduceFloatingPointJitterPassFlags();

  return IntroduceFloatingPointJitter(options.graph->get(), flags.tensor_names,
                                      flags.jitter_amount);
}
}  // namespace tensorflow
tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h (new file, 35 lines)
@@ -0,0 +1,35 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_
#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
// A debug-only pass that introduces error into outputs of specific TF nodes.
// This can be used to check the sensitivity of a TF graph to floating point
// rounding differences.
//
// This pass is controlled by TF_XLA_FLAGS. Please see
// IntroduceFloatingPointJitterPassFlags for information on how to use this.
class IntroduceFloatingPointJitterPass : public GraphOptimizationPass {
 public:
  IntroduceFloatingPointJitterPass() = default;

  Status Run(const GraphOptimizationPassOptions& options) override;
};
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_
tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h (new file, 27 lines)
@@ -0,0 +1,27 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_
#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_

#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
Status IntroduceFloatingPointJitter(Graph* graph,
                                    absl::Span<string const> tensor_names,
                                    float jitter_amount);
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_
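A minimal sketch of calling the entry point declared above, as the tests that follow do (graph construction elided; the tensor names are hypothetical):

    // `graph` is assumed to be a tensorflow::Graph* built elsewhere.
    std::vector<string> tensor_names = {"sigmoid_a", "sigmoid_b"};
    TF_CHECK_OK(
        IntroduceFloatingPointJitter(graph, tensor_names, /*jitter_amount=*/0.01f));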
tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc (new file, 197 lines)
@@ -0,0 +1,197 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h"

#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/linalg_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

using testing::matchers::Const;
using testing::matchers::Inputs;
using testing::matchers::Name;
using testing::matchers::NodeWith;
using testing::matchers::Op;
using testing::matchers::Out;

TEST(IntroduceFloatingPointJitterTest, SingleOutputFP32) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT);
  Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT);

  Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a);
  Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b);

  Output tanh_a = ops::Tanh(root.WithOpName("tanh_a"), sigmoid_a);
  Output tanh_b = ops::Tanh(root.WithOpName("tanh_b"), sigmoid_b);

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("sigmoid_a");
  tensor_names.push_back("sigmoid_b");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a")));
  auto m_sigmoid_a_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a));
  auto m_tanh_a = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_a_with_jitter)));

  auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b")));
  auto m_sigmoid_b_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b));
  auto m_tanh_b = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_b_with_jitter)));

  Node* tanh_a_transformed = testing::FindNodeByName(graph.get(), "tanh_a");
  Node* tanh_b_transformed = testing::FindNodeByName(graph.get(), "tanh_b");

  ASSERT_NE(tanh_a_transformed, nullptr);
  ASSERT_NE(tanh_b_transformed, nullptr);

  EXPECT_THAT(tanh_a_transformed, m_tanh_a);
  EXPECT_THAT(tanh_b_transformed, m_tanh_b);
}

TEST(IntroduceFloatingPointJitterTest, TwoNodesOneUser) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT);
  Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT);

  Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a);
  Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b);

  Output add = ops::Add(root.WithOpName("add"), sigmoid_a, sigmoid_b);

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("sigmoid_a");
  tensor_names.push_back("sigmoid_b");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a")));
  auto m_sigmoid_a_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a));

  auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b")));
  auto m_sigmoid_b_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b));

  auto m_add = NodeWith(Op("Add"), Inputs(Out(m_sigmoid_a_with_jitter),
                                          Out(m_sigmoid_b_with_jitter)));

  Node* add_transformed = testing::FindNodeByName(graph.get(), "add");

  ASSERT_NE(add_transformed, nullptr);

  EXPECT_THAT(add_transformed, m_add);
}

TEST(IntroduceFloatingPointJitterTest, NotFP32) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF);

  Output sigmoid = ops::Sigmoid(root.WithOpName("sigmoid"), input);

  Output tanh = ops::Tanh(root.WithOpName("tanh"), sigmoid);

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("sigmoid");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_sigmoid = Out(NodeWith(Name("sigmoid")));
  auto m_sigmoid_with_jitter =
      NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_sigmoid));
  auto m_tanh = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_with_jitter)));

  Node* tanh_transformed = testing::FindNodeByName(graph.get(), "tanh");

  ASSERT_NE(tanh_transformed, nullptr);

  EXPECT_THAT(tanh_transformed, m_tanh);
}

TEST(IntroduceFloatingPointJitterTest, MultiOutput) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF);

  ops::Svd svd(root.WithOpName("svd"), input);

  Output tanh_s = ops::Tanh(root.WithOpName("tanh_s"), svd.s);
  Output tanh_u = ops::Tanh(root.WithOpName("tanh_u"), svd.u);
  Output tanh_v = ops::Tanh(root.WithOpName("tanh_v"), svd.v);

  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<string> tensor_names;
  tensor_names.push_back("svd:0");
  tensor_names.push_back("svd:2");

  TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
  VLOG(1) << graph->ToGraphDefDebug().DebugString();

  auto m_svd_s = Out(0, NodeWith(Name("svd")));
  auto m_svd_s_with_jitter = Out(
      NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_s)));

  auto m_svd_u = Out(1, NodeWith(Name("svd")));

  auto m_svd_v = Out(2, NodeWith(Name("svd")));
  auto m_svd_v_with_jitter = Out(
      NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_v)));

  auto m_tanh_s = NodeWith(Op("Tanh"), Inputs(m_svd_s_with_jitter));
  auto m_tanh_u = NodeWith(Op("Tanh"), Inputs(m_svd_u));
  auto m_tanh_v = NodeWith(Op("Tanh"), Inputs(m_svd_v_with_jitter));

  Node* tanh_s_transformed = testing::FindNodeByName(graph.get(), "tanh_s");
  ASSERT_NE(tanh_s_transformed, nullptr);

  Node* tanh_u_transformed = testing::FindNodeByName(graph.get(), "tanh_u");
  ASSERT_NE(tanh_u_transformed, nullptr);

  Node* tanh_v_transformed = testing::FindNodeByName(graph.get(), "tanh_v");
  ASSERT_NE(tanh_v_transformed, nullptr);

  EXPECT_THAT(tanh_s_transformed, m_tanh_s);
  EXPECT_THAT(tanh_u_transformed, m_tanh_u);
  EXPECT_THAT(tanh_v_transformed, m_tanh_v);
}

}  // namespace
}  // namespace tensorflow
tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

@@ -31,6 +32,9 @@ namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
                      EncapsulateXlaComputationsPass);

REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
                      IntroduceFloatingPointJitterPass);

// from
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// FunctionalizeControlFlowPass: 27
tensorflow/compiler/jit/node_matchers.cc
@@ -77,6 +77,8 @@ bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
  }

  switch (tensor.dtype()) {
    case DT_HALF:
      return CompareTensor<Eigen::half>(tensor, expected_tensor, listener);
    case DT_FLOAT:
      return CompareTensor<float>(tensor, expected_tensor, listener);
    case DT_DOUBLE: