Add a debug-only pass that introduces a small error to a designated TF node

This can let us check how susceptible a model or a unit test is to floating
point differences.

PiperOrigin-RevId: 240824222
This commit is contained in:
Sanjoy Das 2019-03-28 12:09:43 -07:00 committed by TensorFlower Gardener
parent 36611db67d
commit ef5519ab43
9 changed files with 478 additions and 14 deletions

View File

@ -237,6 +237,7 @@ cc_library(
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)
@ -497,6 +498,7 @@ cc_library(
"encapsulate_xla_computations_pass.cc",
"extract_outside_compilation_pass.cc",
"increase_dynamism_for_auto_jit_pass.cc",
"introduce_floating_point_jitter_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@ -509,6 +511,7 @@ cc_library(
"encapsulate_xla_computations_pass.h",
"extract_outside_compilation_pass.h",
"increase_dynamism_for_auto_jit_pass.h",
"introduce_floating_point_jitter_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
@ -523,6 +526,7 @@ cc_library(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/ops:xla_ops",
@ -551,6 +555,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@ -636,6 +641,8 @@ tf_cc_test(
"encapsulate_xla_computations_pass_test.cc",
"extract_outside_compilation_pass_test.cc",
"increase_dynamism_for_auto_jit_pass_test.cc",
"introduce_floating_point_jitter_pass_internal.h",
"introduce_floating_point_jitter_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
@ -677,6 +684,7 @@ tf_cc_test(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <mutex> // NOLINT
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"
@ -26,6 +27,7 @@ BuildXlaOpsPassFlags* build_ops_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
@ -86,21 +88,38 @@ void AllocateAndParseFlags() {
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_enable_lazy_compilation",
&build_ops_flags->tf_xla_enable_lazy_compilation, ""),
Flag("tf_xla_print_cluster_outputs",
&build_ops_flags->tf_xla_print_cluster_outputs,
"If true then insert Print nodes to print out values produced by "
"XLA clusters."),
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
};
flag_list = new std::vector<Flag>(
{Flag("tf_xla_enable_lazy_compilation",
&build_ops_flags->tf_xla_enable_lazy_compilation, ""),
Flag("tf_xla_print_cluster_outputs",
&build_ops_flags->tf_xla_print_cluster_outputs,
"If true then insert Print nodes to print out values produced by "
"XLA clusters."),
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
Flag("tf_introduce_floating_point_jitter_to_tensors",
setter_for_jitter_tensor_names, "",
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names."),
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The Tensors to add the jitter to. The tensors are named in the "
"TensorId format of <node name>:<output idx>.")});
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
}
@ -127,9 +146,14 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
return *ops_flags;
}
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return *jitter_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
} // namespace tensorflow

View File

@ -82,6 +82,17 @@ struct BuildXlaOpsPassFlags {
bool tf_xla_print_cluster_outputs;
};
// Flags for the IntroduceFloatingPointJitter pass.
struct IntroduceFloatingPointJitterPassFlags {
// The amount of jitter to introduce. This amount is added to each element in
// the tensors named in `tensor_names.
float jitter_amount;
// The Tensors to add the jitter to. The tensors are named in the TensorId
// format of <node name>:<output idx>.
std::vector<string> tensor_names;
};
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@ -94,6 +105,9 @@ const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
XlaDeviceFlags* GetXlaDeviceFlags();
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

View File

@ -0,0 +1,153 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/graph/tensor_id.h"
namespace tensorflow {
namespace {
std::vector<std::pair<Node*, std::vector<int>>> GetNodesToModify(
const Graph& g, absl::Span<const string> tensor_names) {
absl::flat_hash_map<string, Node*> name_to_node;
for (Node* n : g.op_nodes()) {
name_to_node[n->name()] = n;
}
absl::flat_hash_map<Node*, std::vector<int>> nodes_to_modify_map;
for (const string& tensor_name : tensor_names) {
TensorId tensor_id = ParseTensorName(tensor_name);
auto it = name_to_node.find(tensor_id.node());
DCHECK(it != name_to_node.end());
nodes_to_modify_map[it->second].push_back(tensor_id.index());
}
std::vector<std::pair<Node*, std::vector<int>>> nodes_to_modify;
absl::c_copy(nodes_to_modify_map, std::back_inserter(nodes_to_modify));
absl::c_sort(nodes_to_modify,
[](const std::pair<Node*, std::vector<int>>& a,
const std::pair<Node*, std::vector<int>>& b) {
return a.first->id() < b.first->id();
});
for (auto& p : nodes_to_modify) {
absl::c_sort(p.second);
p.second.erase(std::unique(p.second.begin(), p.second.end()),
p.second.end());
}
return nodes_to_modify;
}
Status IntroduceJitterToTensor(
Graph* g, Node* n, int oidx, float jitter_amount,
absl::flat_hash_map<std::pair<DataType, Node*>, Output>*
node_to_jitter_constant) {
std::vector<const Edge*> edges_to_update;
absl::c_copy_if(n->out_edges(), std::back_inserter(edges_to_update),
[&](const Edge* e) { return e->src_output() == oidx; });
if (edges_to_update.empty()) {
VLOG(1) << "No users for " << TensorId(n->name(), oidx).ToString();
return Status::OK();
}
VLOG(1) << "Updating " << edges_to_update.size() << " users for "
<< TensorId(n->name(), oidx).ToString();
Status status;
Scope s = NewInternalScope(g, &status, /*refiner=*/nullptr)
.NewSubScope(absl::StrCat(n->name(), "/jitter"));
Output node_out(n, oidx);
Output jitter_constant;
DataType dtype = n->output_type(oidx);
auto it = node_to_jitter_constant->find({dtype, n});
if (it == node_to_jitter_constant->end()) {
Tensor constant_tensor;
if (dtype == DT_FLOAT) {
constant_tensor = Tensor(static_cast<float>(jitter_amount));
} else if (dtype == DT_HALF) {
constant_tensor = Tensor(Eigen::half(jitter_amount));
} else {
return errors::Unimplemented("Only float and half are supported");
}
jitter_constant =
ops::Const(s.WithOpName("jitter_amount"), constant_tensor);
(*node_to_jitter_constant)[{dtype, n}] = jitter_constant;
} else {
jitter_constant = it->second;
}
Output jittered_output =
ops::Add(s.NewSubScope(absl::StrCat(oidx)).WithOpName("jittered_output"),
jitter_constant, node_out);
TF_RETURN_IF_ERROR(status);
for (const Edge* e : edges_to_update) {
VLOG(3) << "Updating " << e->dst()->name();
TF_RETURN_IF_ERROR(
g->UpdateEdge(jittered_output.node(), 0, e->dst(), e->dst_input()));
}
// Add a control edge to make sure that the two inputs to jittered_output are
// from the same frame.
g->AddControlEdge(n, jitter_constant.node());
return Status::OK();
}
} // namespace
Status IntroduceFloatingPointJitter(Graph* graph,
absl::Span<string const> tensor_names,
float jitter_amount) {
if (tensor_names.empty()) {
VLOG(3) << "Nothing to do";
return Status::OK();
}
std::vector<std::pair<Node*, std::vector<int>>> nodes_to_modify =
GetNodesToModify(*graph, tensor_names);
absl::flat_hash_map<std::pair<DataType, Node*>, Output>
node_to_jitter_constant;
for (const auto& p : nodes_to_modify) {
for (int oidx : p.second) {
TF_RETURN_IF_ERROR(IntroduceJitterToTensor(
graph, p.first, oidx, jitter_amount, &node_to_jitter_constant));
}
}
return Status::OK();
}
Status IntroduceFloatingPointJitterPass::Run(
const GraphOptimizationPassOptions& options) {
const IntroduceFloatingPointJitterPassFlags& flags =
GetIntroduceFloatingPointJitterPassFlags();
return IntroduceFloatingPointJitter(options.graph->get(), flags.tensor_names,
flags.jitter_amount);
}
} // namespace tensorflow

View File

@ -0,0 +1,35 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_
#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// A debug-only pass that introduces error into outputs of specific TF nodes.
// This can be used to check the sensitivity of a TF graph to floating point
// rounding differences.
//
// This pass is controlled by TF_XLA_FLAGS. Please see
// IntroduceFloatingPointJitterPassFlags for information on how to use this.
class IntroduceFloatingPointJitterPass : public GraphOptimizationPass {
public:
IntroduceFloatingPointJitterPass() = default;
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_

View File

@ -0,0 +1,27 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_
#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
Status IntroduceFloatingPointJitter(Graph* graph,
absl::Span<string const> tensor_names,
float jitter_amount);
}
#endif // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_

View File

@ -0,0 +1,197 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/linalg_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using testing::matchers::Const;
using testing::matchers::Inputs;
using testing::matchers::Name;
using testing::matchers::NodeWith;
using testing::matchers::Op;
using testing::matchers::Out;
TEST(IntroduceFloatingPointJitterTest, SingleOutputFP32) {
Scope root = Scope::NewRootScope().ExitOnError();
Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT);
Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT);
Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a);
Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b);
Output tanh_a = ops::Tanh(root.WithOpName("tanh_a"), sigmoid_a);
Output tanh_b = ops::Tanh(root.WithOpName("tanh_b"), sigmoid_b);
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
std::vector<string> tensor_names;
tensor_names.push_back("sigmoid_a");
tensor_names.push_back("sigmoid_b");
TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
VLOG(1) << graph->ToGraphDefDebug().DebugString();
auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a")));
auto m_sigmoid_a_with_jitter =
NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a));
auto m_tanh_a = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_a_with_jitter)));
auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b")));
auto m_sigmoid_b_with_jitter =
NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b));
auto m_tanh_b = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_b_with_jitter)));
Node* tanh_a_transformed = testing::FindNodeByName(graph.get(), "tanh_a");
Node* tanh_b_transformed = testing::FindNodeByName(graph.get(), "tanh_b");
ASSERT_NE(tanh_a_transformed, nullptr);
ASSERT_NE(tanh_b_transformed, nullptr);
EXPECT_THAT(tanh_a_transformed, m_tanh_a);
EXPECT_THAT(tanh_b_transformed, m_tanh_b);
}
TEST(IntroduceFloatingPointJitterTest, TwoNodesOneUser) {
Scope root = Scope::NewRootScope().ExitOnError();
Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT);
Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT);
Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a);
Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b);
Output add = ops::Add(root.WithOpName("add"), sigmoid_a, sigmoid_b);
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
std::vector<string> tensor_names;
tensor_names.push_back("sigmoid_a");
tensor_names.push_back("sigmoid_b");
TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
VLOG(1) << graph->ToGraphDefDebug().DebugString();
auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a")));
auto m_sigmoid_a_with_jitter =
NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a));
auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b")));
auto m_sigmoid_b_with_jitter =
NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b));
auto m_add = NodeWith(Op("Add"), Inputs(Out(m_sigmoid_a_with_jitter),
Out(m_sigmoid_b_with_jitter)));
Node* add_transformed = testing::FindNodeByName(graph.get(), "add");
ASSERT_NE(add_transformed, nullptr);
EXPECT_THAT(add_transformed, m_add);
}
TEST(IntroduceFloatingPointJitterTest, NotFP32) {
Scope root = Scope::NewRootScope().ExitOnError();
Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF);
Output sigmoid = ops::Sigmoid(root.WithOpName("sigmoid"), input);
Output tanh = ops::Tanh(root.WithOpName("tanh"), sigmoid);
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
std::vector<string> tensor_names;
tensor_names.push_back("sigmoid");
TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
VLOG(1) << graph->ToGraphDefDebug().DebugString();
auto m_sigmoid = Out(NodeWith(Name("sigmoid")));
auto m_sigmoid_with_jitter =
NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_sigmoid));
auto m_tanh = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_with_jitter)));
Node* tanh_transformed = testing::FindNodeByName(graph.get(), "tanh");
ASSERT_NE(tanh_transformed, nullptr);
EXPECT_THAT(tanh_transformed, m_tanh);
}
TEST(IntroduceFloatingPointJitterTest, MultiOutput) {
Scope root = Scope::NewRootScope().ExitOnError();
Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF);
ops::Svd svd(root.WithOpName("svd"), input);
Output tanh_s = ops::Tanh(root.WithOpName("tanh_s"), svd.s);
Output tanh_u = ops::Tanh(root.WithOpName("tanh_u"), svd.u);
Output tanh_v = ops::Tanh(root.WithOpName("tanh_v"), svd.v);
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(graph.get()));
std::vector<string> tensor_names;
tensor_names.push_back("svd:0");
tensor_names.push_back("svd:2");
TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f));
VLOG(1) << graph->ToGraphDefDebug().DebugString();
auto m_svd_s = Out(0, NodeWith(Name("svd")));
auto m_svd_s_with_jitter = Out(
NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_s)));
auto m_svd_u = Out(1, NodeWith(Name("svd")));
auto m_svd_v = Out(2, NodeWith(Name("svd")));
auto m_svd_v_with_jitter = Out(
NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_v)));
auto m_tanh_s = NodeWith(Op("Tanh"), Inputs(m_svd_s_with_jitter));
auto m_tanh_u = NodeWith(Op("Tanh"), Inputs(m_svd_u));
auto m_tanh_v = NodeWith(Op("Tanh"), Inputs(m_svd_v_with_jitter));
Node* tanh_s_transformed = testing::FindNodeByName(graph.get(), "tanh_s");
ASSERT_NE(tanh_s_transformed, nullptr);
Node* tanh_u_transformed = testing::FindNodeByName(graph.get(), "tanh_u");
ASSERT_NE(tanh_u_transformed, nullptr);
Node* tanh_v_transformed = testing::FindNodeByName(graph.get(), "tanh_v");
ASSERT_NE(tanh_v_transformed, nullptr);
EXPECT_THAT(tanh_s_transformed, m_tanh_s);
EXPECT_THAT(tanh_u_transformed, m_tanh_u);
EXPECT_THAT(tanh_v_transformed, m_tanh_v);
}
} // namespace
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
@ -31,6 +32,9 @@ namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
EncapsulateXlaComputationsPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25,
IntroduceFloatingPointJitterPass);
// from
// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
// FunctionalizeControlFlowPass: 27

View File

@ -77,6 +77,8 @@ bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
}
switch (tensor.dtype()) {
case DT_HALF:
return CompareTensor<Eigen::half>(tensor, expected_tensor, listener);
case DT_FLOAT:
return CompareTensor<float>(tensor, expected_tensor, listener);
case DT_DOUBLE: