Merge commit for internal changes

This commit is contained in:
Vijay Vasudevan 2017-05-05 16:33:39 -07:00
commit 9dd8e7aec9
152 changed files with 6598 additions and 1267 deletions

3
.gitignore vendored
View File

@ -7,11 +7,8 @@ node_modules
/bazel_pip /bazel_pip
/third_party/eigen3/mkl_include /third_party/eigen3/mkl_include
/third_party/mkl/* /third_party/mkl/*
/third_party/py/numpy/numpy_include
/tools/python_bin_path.sh /tools/python_bin_path.sh
/tools/git/gen /tools/git/gen
/util/python/python_include
/util/python/python_lib
/pip_test /pip_test
/_python_build /_python_build
*.pyc *.pyc

View File

@ -263,6 +263,7 @@ filegroup(
"//tensorflow/contrib/seq2seq:all_files", "//tensorflow/contrib/seq2seq:all_files",
"//tensorflow/contrib/session_bundle:all_files", "//tensorflow/contrib/session_bundle:all_files",
"//tensorflow/contrib/session_bundle/example:all_files", "//tensorflow/contrib/session_bundle/example:all_files",
"//tensorflow/contrib/signal:all_files",
"//tensorflow/contrib/slim:all_files", "//tensorflow/contrib/slim:all_files",
"//tensorflow/contrib/slim/python/slim/data:all_files", "//tensorflow/contrib/slim/python/slim/data:all_files",
"//tensorflow/contrib/slim/python/slim/nets:all_files", "//tensorflow/contrib/slim/python/slim/nets:all_files",
@ -326,6 +327,48 @@ filegroup(
"//tensorflow/tensorboard/backend:all_files", "//tensorflow/tensorboard/backend:all_files",
"//tensorflow/tensorboard/backend/event_processing:all_files", "//tensorflow/tensorboard/backend/event_processing:all_files",
"//tensorflow/tensorboard/components:all_files", "//tensorflow/tensorboard/components:all_files",
"//tensorflow/tensorboard/components/tf_audio_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_backend:all_files",
"//tensorflow/tensorboard/components/tf_backend_d3v4:all_files",
"//tensorflow/tensorboard/components/tf_color_scale:all_files",
"//tensorflow/tensorboard/components/tf_color_scale/demo:all_files",
"//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files",
"//tensorflow/tensorboard/components/tf_dashboard_common:all_files",
"//tensorflow/tensorboard/components/tf_dashboard_common/demo:all_files",
"//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files",
"//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_distribution_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_globals:all_files",
"//tensorflow/tensorboard/components/tf_globals_d3v4:all_files",
"//tensorflow/tensorboard/components/tf_graph_common:all_files",
"//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_image_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_imports:all_files",
"//tensorflow/tensorboard/components/tf_imports_d3v4:all_files",
"//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/tf_storage:all_files",
"//tensorflow/tensorboard/components/tf_storage_d3v4:all_files",
"//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
"//tensorflow/tensorboard/components/tf_text_dashboard/demo:all_files",
"//tensorflow/tensorboard/components/vz_data_summary:all_files",
"//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
"//tensorflow/tensorboard/components/vz_distribution_chart/demo:all_files",
"//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files",
"//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
"//tensorflow/tensorboard/components/vz_histogram_timeseries/demo:all_files",
"//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files",
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
"//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
"//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files",
"//tensorflow/tensorboard/components/vz_projector:all_files",
"//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
"//tensorflow/tensorboard/components/vz_sorting:all_files",
"//tensorflow/tensorboard/components/vz_sorting/test:all_files",
"//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files",
"//tensorflow/tensorboard/lib:all_files", "//tensorflow/tensorboard/lib:all_files",
"//tensorflow/tensorboard/plugins:all_files", "//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files", "//tensorflow/tensorboard/plugins/projector:all_files",

View File

@ -28,11 +28,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
namespace op = xla::testing::opcode_matchers;
namespace xla { namespace xla {
namespace { namespace {
namespace op = xla::testing::opcode_matchers;
using ::testing::_;
class HloRematerializationTest : public HloTestBase { class HloRematerializationTest : public HloTestBase {
protected: protected:
// Creates and returns a computation which can benefit from // Creates and returns a computation which can benefit from
@ -145,11 +147,9 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// Find and save the original broadcast instruction which should be // Find and save the original broadcast instruction which should be
// rematerialized. // rematerialized.
const HloInstruction* slice = computation->root_instruction(); const HloInstruction* slice = computation->root_instruction();
ASSERT_EQ(HloOpcode::kSlice, slice->opcode()); ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
const HloInstruction* concat = slice->operand(0); const HloInstruction* concat = slice->operand(0);
ASSERT_EQ(HloOpcode::kConcatenate, concat->opcode());
const HloInstruction* bcast = concat->operand(0); const HloInstruction* bcast = concat->operand(0);
ASSERT_EQ(HloOpcode::kBroadcast, bcast->opcode());
SequentialHloOrdering::HloModuleSequence sequence; SequentialHloOrdering::HloModuleSequence sequence;
// Computation requires 16KB without rematerialization, but uses only 12KB // Computation requires 16KB without rematerialization, but uses only 12KB
@ -165,8 +165,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// The broadcast should have been rematerialized. // The broadcast should have been rematerialized.
const HloInstruction* remat_bcast = concat->operand(0); const HloInstruction* remat_bcast = concat->operand(0);
EXPECT_EQ(HloOpcode::kBroadcast, remat_bcast->opcode()); EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));
EXPECT_NE(bcast, remat_bcast);
// The rematerialized broadcast should be immediate before the concat in the // The rematerialized broadcast should be immediate before the concat in the
// sequence. // sequence.

View File

@ -68,9 +68,8 @@ void CleanNodeName(string* name) {
} }
Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
LOG(INFO) << "Adding computation " << computation.name(); VLOG(2) << "Adding computation " << computation.name();
for (auto embedded : computation.MakeEmbeddedComputationsList()) { for (auto embedded : computation.MakeEmbeddedComputationsList()) {
LOG(INFO) << "Adding embedded computation " << embedded->name();
for (auto& instruction : embedded->instructions()) { for (auto& instruction : embedded->instructions()) {
TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
} }

View File

@ -85,7 +85,7 @@ def _init_clusters_random(data, num_clusters, random_seed):
maxval=math_ops.cast(num_data, dtypes.int64), maxval=math_ops.cast(num_data, dtypes.int64),
seed=random_seed, seed=random_seed,
dtype=dtypes.int64) dtype=dtypes.int64)
indices = indices % math_ops.cast(num_data, dtypes.int64) indices %= math_ops.cast(num_data, dtypes.int64)
clusters_init = embedding_lookup(data, indices, partition_strategy='div') clusters_init = embedding_lookup(data, indices, partition_strategy='div')
return clusters_init return clusters_init

View File

@ -35,8 +35,8 @@ class GridRNNCellTest(test.TestCase):
def testGrid2BasicLSTMCell(self): def testGrid2BasicLSTMCell(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.2)) as root_scope: 'root', initializer=init_ops.constant_initializer(0.2)) as root_scope:
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@ -51,21 +51,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2)) self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x:
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), np.array([[1., 1., 1.]]),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))}) m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2)) self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2)) self.assertEqual(res_s[0].h.shape, (1, 2))
self.assertEqual(res_s[1].c.shape, (1, 2)) self.assertEqual(res_s[1].c.shape, (1, 2))
self.assertEqual(res_s[1].h.shape, (1, 2)) self.assertEqual(res_s[1].h.shape, (1, 2))
self.assertAllClose(res_g, ([[0.36617181, 0.36617181]], )) self.assertAllClose(res_g, ([[0.36617181, 0.36617181]],))
self.assertAllClose(res_s, (([[0.71053141, 0.71053141]], self.assertAllClose(
[[0.36617181, 0.36617181]]), res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
([[0.72320831, 0.80555487]], ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
[[0.39102408, 0.42150158]])))
# emulate a loop through the input sequence, # emulate a loop through the input sequence,
# where we call cell() multiple times # where we call cell() multiple times
@ -78,22 +79,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s2[1].h.get_shape(), (1, 2)) self.assertEqual(s2[1].h.get_shape(), (1, 2))
res_g2, res_s2 = sess.run([g2, s2], res_g2, res_s2 = sess.run([g2, s2],
{x: np.array([[2., 2., 2.]]), m: res_s}) {x: np.array([[2., 2., 2.]]),
m: res_s})
self.assertEqual(res_g2[0].shape, (1, 2)) self.assertEqual(res_g2[0].shape, (1, 2))
self.assertEqual(res_s2[0].c.shape, (1, 2)) self.assertEqual(res_s2[0].c.shape, (1, 2))
self.assertEqual(res_s2[0].h.shape, (1, 2)) self.assertEqual(res_s2[0].h.shape, (1, 2))
self.assertEqual(res_s2[1].c.shape, (1, 2)) self.assertEqual(res_s2[1].c.shape, (1, 2))
self.assertEqual(res_s2[1].h.shape, (1, 2)) self.assertEqual(res_s2[1].h.shape, (1, 2))
self.assertAllClose(res_g2[0], [[0.58847463, 0.58847463]]) self.assertAllClose(res_g2[0], [[0.58847463, 0.58847463]])
self.assertAllClose(res_s2, (([[1.40469193, 1.40469193]], self.assertAllClose(
[[0.58847463, 0.58847463]]), res_s2, (([[1.40469193, 1.40469193]], [[0.58847463, 0.58847463]]),
([[0.97726452, 1.04626071]], ([[0.97726452, 1.04626071]], [[0.4927212, 0.51137757]])))
[[0.4927212, 0.51137757]])))
def testGrid2BasicLSTMCellTied(self): def testGrid2BasicLSTMCellTied(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope( with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.2)): 'root', initializer=init_ops.constant_initializer(0.2)):
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@ -108,10 +109,12 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2)) self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x:
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), np.array([[1., 1., 1.]]),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))}) m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2)) self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2)) self.assertEqual(res_s[0].h.shape, (1, 2))
@ -119,29 +122,27 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[1].h.shape, (1, 2)) self.assertEqual(res_s[1].h.shape, (1, 2))
self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]]) self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]])
self.assertAllClose(res_s, (([[0.71053141, 0.71053141]], self.assertAllClose(
[[0.36617181, 0.36617181]]), res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
([[0.72320831, 0.80555487]], ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
[[0.39102408, 0.42150158]])))
res_g, res_s = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res_s}) res_g, res_s = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res_s})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]]) self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]])
self.assertAllClose(res_s, (([[0.71200621, 0.71200621]], self.assertAllClose(
[[0.36703536, 0.36703536]]), res_s, (([[0.71200621, 0.71200621]], [[0.36703536, 0.36703536]]),
([[0.80941606, 0.87550586]], ([[0.80941606, 0.87550586]], [[0.40108523, 0.42199609]])))
[[0.40108523, 0.42199609]])))
def testGrid2BasicLSTMCellWithRelu(self): def testGrid2BasicLSTMCellWithRelu(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.2)): 'root', initializer=init_ops.constant_initializer(0.2)):
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),) m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid2BasicLSTMCell( cell = grid_rnn_cell.Grid2BasicLSTMCell(
2, tied=False, non_recurrent_fn=nn_ops.relu) 2, tied=False, non_recurrent_fn=nn_ops.relu)
self.assertEqual(cell.state_size, ((2, 2), )) self.assertEqual(cell.state_size, ((2, 2),))
g, s = cell(x, m) g, s = cell(x, m)
self.assertEqual(g[0].get_shape(), (1, 2)) self.assertEqual(g[0].get_shape(), (1, 2))
@ -149,21 +150,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2)) self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x: np.array([[1., 1., 1.]]),
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )}) m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertAllClose(res_g[0], [[0.31667367, 0.31667367]]) self.assertAllClose(res_g[0], [[0.31667367, 0.31667367]])
self.assertAllClose(res_s, (([[0.29530135, 0.37520045]], self.assertAllClose(res_s, (([[0.29530135, 0.37520045]],
[[0.17044567, 0.21292259]]), )) [[0.17044567, 0.21292259]]),))
"""LSTMCell """LSTMCell
""" """
def testGrid2LSTMCell(self): def testGrid2LSTMCell(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@ -178,10 +180,12 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2)) self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x:
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), np.array([[1., 1., 1.]]),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))}) m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2)) self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2)) self.assertEqual(res_s[0].h.shape, (1, 2))
@ -189,15 +193,14 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[1].h.shape, (1, 2)) self.assertEqual(res_s[1].h.shape, (1, 2))
self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]]) self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
self.assertAllClose(res_s, (([[2.41515064, 2.41515064]], self.assertAllClose(
[[0.95686918, 0.95686918]]), res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
([[1.38917875, 1.49043763]], ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
[[0.83884692, 0.86036491]])))
def testGrid2LSTMCellTied(self): def testGrid2LSTMCellTied(self):
with self.test_session(use_gpu=False) as sess: with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2]))) (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@ -212,10 +215,12 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2)) self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x:
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), np.array([[1., 1., 1.]]),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))}) m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2)) self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2)) self.assertEqual(res_s[0].h.shape, (1, 2))
@ -223,15 +228,14 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[1].h.shape, (1, 2)) self.assertEqual(res_s[1].h.shape, (1, 2))
self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]]) self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
self.assertAllClose(res_s, (([[2.41515064, 2.41515064]], self.assertAllClose(
[[0.95686918, 0.95686918]]), res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
([[1.38917875, 1.49043763]], ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
[[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self): def testGrid2LSTMCellWithRelu(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),) m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid2LSTMCell( cell = grid_rnn_cell.Grid2LSTMCell(
@ -244,21 +248,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2)) self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x: np.array([[1., 1., 1.]]),
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )}) m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertAllClose(res_g[0], [[2.1831727, 2.1831727]]) self.assertAllClose(res_g[0], [[2.1831727, 2.1831727]])
self.assertAllClose(res_s, (([[0.92270052, 1.02325559]], self.assertAllClose(res_s, (([[0.92270052, 1.02325559]],
[[0.66159075, 0.70475441]]), )) [[0.66159075, 0.70475441]]),))
"""RNNCell """RNNCell
""" """
def testGrid2BasicRNNCell(self): def testGrid2BasicRNNCell(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2]) x = array_ops.zeros([2, 2])
m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2])) m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
cell = grid_rnn_cell.Grid2BasicRNNCell(2) cell = grid_rnn_cell.Grid2BasicRNNCell(2)
@ -270,26 +275,26 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].get_shape(), (2, 2)) self.assertEqual(s[1].get_shape(), (2, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1.], [2., 2.]]), x:
m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[1., 1.], [2., 2.]]),
np.array([[0.1, 0.1], [0.2, 0.2]]))}) m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
[0.2, 0.2]]))
})
self.assertEqual(res_g[0].shape, (2, 2)) self.assertEqual(res_g[0].shape, (2, 2))
self.assertEqual(res_s[0].shape, (2, 2)) self.assertEqual(res_s[0].shape, (2, 2))
self.assertEqual(res_s[1].shape, (2, 2)) self.assertEqual(res_s[1].shape, (2, 2))
self.assertAllClose(res_g, ([[0.94685763, 0.94685763], self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
[0.99480951, 0.99480951]], )) [0.99480951, 0.99480951]],))
self.assertAllClose(res_s, self.assertAllClose(
([[0.94685763, 0.94685763], res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
[0.99480951, 0.99480951]], [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
[[0.80049908, 0.80049908],
[0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self): def testGrid2BasicRNNCellTied(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2]) x = array_ops.zeros([2, 2])
m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2])) m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True) cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
@ -301,55 +306,55 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].get_shape(), (2, 2)) self.assertEqual(s[1].get_shape(), (2, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1.], [2., 2.]]), x:
m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[1., 1.], [2., 2.]]),
np.array([[0.1, 0.1], [0.2, 0.2]]))}) m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
[0.2, 0.2]]))
})
self.assertEqual(res_g[0].shape, (2, 2)) self.assertEqual(res_g[0].shape, (2, 2))
self.assertEqual(res_s[0].shape, (2, 2)) self.assertEqual(res_s[0].shape, (2, 2))
self.assertEqual(res_s[1].shape, (2, 2)) self.assertEqual(res_s[1].shape, (2, 2))
self.assertAllClose(res_g, ([[0.94685763, 0.94685763], self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
[0.99480951, 0.99480951]], )) [0.99480951, 0.99480951]],))
self.assertAllClose(res_s, self.assertAllClose(
([[0.94685763, 0.94685763], res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
[0.99480951, 0.99480951]], [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
[[0.80049908, 0.80049908],
[0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self): def testGrid2BasicRNNCellWithRelu(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2]) x = array_ops.zeros([1, 2])
m = (array_ops.zeros([1, 2]), ) m = (array_ops.zeros([1, 2]),)
cell = grid_rnn_cell.Grid2BasicRNNCell( cell = grid_rnn_cell.Grid2BasicRNNCell(2, non_recurrent_fn=nn_ops.relu)
2, non_recurrent_fn=nn_ops.relu) self.assertEqual(cell.state_size, (2,))
self.assertEqual(cell.state_size, (2, ))
g, s = cell(x, m) g, s = cell(x, m)
self.assertEqual(g[0].get_shape(), (1, 2)) self.assertEqual(g[0].get_shape(), (1, 2))
self.assertEqual(s[0].get_shape(), (1, 2)) self.assertEqual(s[0].get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run([g, s], {x: np.array([[1., 1.]]), res_g, res_s = sess.run(
m: np.array([[0.1, 0.1]])}) [g, s], {x: np.array([[1., 1.]]),
m: np.array([[0.1, 0.1]])})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].shape, (1, 2)) self.assertEqual(res_s[0].shape, (1, 2))
self.assertAllClose(res_g, ([[1.80049896, 1.80049896]], )) self.assertAllClose(res_g, ([[1.80049896, 1.80049896]],))
self.assertAllClose(res_s, ([[0.80049896, 0.80049896]], )) self.assertAllClose(res_s, ([[0.80049896, 0.80049896]],))
"""1-LSTM """1-LSTM
""" """
def testGrid1LSTMCell(self): def testGrid1LSTMCell(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)) as root_scope: 'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), ) m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True) cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True)
self.assertEqual(cell.state_size, ((2, 2), )) self.assertEqual(cell.state_size, ((2, 2),))
g, s = cell(x, m) g, s = cell(x, m)
self.assertEqual(g[0].get_shape(), (1, 2)) self.assertEqual(g[0].get_shape(), (1, 2))
@ -357,17 +362,17 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2)) self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x: np.array([[1., 1., 1.]]),
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )}) m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2)) self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2)) self.assertEqual(res_s[0].h.shape, (1, 2))
self.assertAllClose(res_g, ([[0.91287315, 0.91287315]], )) self.assertAllClose(res_g, ([[0.91287315, 0.91287315]],))
self.assertAllClose(res_s, self.assertAllClose(res_s, (([[2.26285243, 2.26285243]],
(([[2.26285243, 2.26285243]], [[0.91287315, 0.91287315]]),))
[[0.91287315, 0.91287315]]), ))
root_scope.reuse_variables() root_scope.reuse_variables()
@ -383,10 +388,9 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s2[0].c.shape, (1, 2)) self.assertEqual(res_s2[0].c.shape, (1, 2))
self.assertEqual(res_s2[0].h.shape, (1, 2)) self.assertEqual(res_s2[0].h.shape, (1, 2))
self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]], )) self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]],))
self.assertAllClose(res_s2, self.assertAllClose(res_s2, (([[2.79966092, 2.79966092]],
(([[2.79966092, 2.79966092]], [[0.9032144, 0.9032144]]),))
[[0.9032144, 0.9032144]]), ))
g3, s3 = cell(x2, m) g3, s3 = cell(x2, m)
self.assertEqual(g3[0].get_shape(), (1, 2)) self.assertEqual(g3[0].get_shape(), (1, 2))
@ -398,18 +402,17 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_g3[0].shape, (1, 2)) self.assertEqual(res_g3[0].shape, (1, 2))
self.assertEqual(res_s3[0].c.shape, (1, 2)) self.assertEqual(res_s3[0].c.shape, (1, 2))
self.assertEqual(res_s3[0].h.shape, (1, 2)) self.assertEqual(res_s3[0].h.shape, (1, 2))
self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]], )) self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]],))
self.assertAllClose(res_s3, self.assertAllClose(res_s3, (([[3.3529923, 3.3529923]],
(([[3.3529923, 3.3529923]], [[0.92727238, 0.92727238]]),))
[[0.92727238, 0.92727238]]), ))
"""3-LSTM """3-LSTM
""" """
def testGrid3LSTMCell(self): def testGrid3LSTMCell(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
@ -427,11 +430,13 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[2].h.get_shape(), (1, 2)) self.assertEqual(s[2].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x:
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), np.array([[1., 1., 1.]]),
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])), m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
(np.array([[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))}) (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])), (np.array(
[[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))
})
self.assertEqual(res_g[0].shape, (1, 2)) self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2)) self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2)) self.assertEqual(res_s[0].h.shape, (1, 2))
@ -440,21 +445,19 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[2].c.shape, (1, 2)) self.assertEqual(res_s[2].c.shape, (1, 2))
self.assertEqual(res_s[2].h.shape, (1, 2)) self.assertEqual(res_s[2].h.shape, (1, 2))
self.assertAllClose(res_g, ([[0.96892911, 0.96892911]], )) self.assertAllClose(res_g, ([[0.96892911, 0.96892911]],))
self.assertAllClose(res_s, (([[2.45227885, 2.45227885]], self.assertAllClose(
[[0.96892911, 0.96892911]]), res_s, (([[2.45227885, 2.45227885]], [[0.96892911, 0.96892911]]),
([[1.33592629, 1.4373529]], ([[1.33592629, 1.4373529]], [[0.80867189, 0.83247656]]),
[[0.80867189, 0.83247656]]), ([[0.7317788, 0.63205892]], [[0.56548983, 0.50446129]])))
([[0.7317788, 0.63205892]],
[[0.56548983, 0.50446129]])))
"""Edge cases """Edge cases
""" """
def testGridRNNEdgeCasesLikeRelu(self): def testGridRNNEdgeCasesLikeRelu(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2]) x = array_ops.zeros([3, 2])
m = () m = ()
@ -471,18 +474,18 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s, ()) self.assertEqual(s, ())
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s],
[g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])}) {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
self.assertEqual(res_g[0].shape, (3, 2)) self.assertEqual(res_g[0].shape, (3, 2))
self.assertEqual(res_s, ()) self.assertEqual(res_s, ())
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]], )) self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self): def testGridRNNEdgeCasesNoOutput(self):
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2]) x = array_ops.zeros([1, 2])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), ) m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
# This cell produces no output # This cell produces no output
cell = grid_rnn_cell.GridRNNCell( cell = grid_rnn_cell.GridRNNCell(
@ -498,9 +501,10 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2)) self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res_g, res_s = sess.run( res_g, res_s = sess.run([g, s], {
[g, s], {x: np.array([[1., 1.]]), x: np.array([[1., 1.]]),
m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])), )}) m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])),)
})
self.assertEqual(res_g, ()) self.assertEqual(res_g, ())
self.assertEqual(res_s[0].c.shape, (1, 2)) self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2)) self.assertEqual(res_s[0].h.shape, (1, 2))
@ -561,8 +565,9 @@ class GridRNNCellTest(test.TestCase):
cell = grid_rnn_cell.Grid2LSTMCell( cell = grid_rnn_cell.Grid2LSTMCell(
num_units=num_units, non_recurrent_fn=nn_ops.relu) num_units=num_units, non_recurrent_fn=nn_ops.relu)
inputs = max_length * [array_ops.placeholder( inputs = max_length * [
dtypes.float32, shape=(batch_size, input_size))] array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
@ -600,8 +605,9 @@ class GridRNNCellTest(test.TestCase):
cell = grid_rnn_cell.Grid3LSTMCell( cell = grid_rnn_cell.Grid3LSTMCell(
num_units=num_units, non_recurrent_fn=nn_ops.relu) num_units=num_units, non_recurrent_fn=nn_ops.relu)
inputs = max_length * [array_ops.placeholder( inputs = max_length * [
dtypes.float32, shape=(batch_size, input_size))] array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
@ -671,19 +677,17 @@ class GridRNNCellTest(test.TestCase):
self.assertTrue(np.all(np.isfinite(v))) self.assertTrue(np.all(np.isfinite(v)))
def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self): def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self):
"""Test for #4296 """Test for #4296."""
"""
input_size = 5 input_size = 5
max_length = 6 # unrolled up to this length max_length = 6 # unrolled up to this length
num_units = 2 num_units = 2
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units) cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)
inputs = max_length * [ inputs = max_length * [
array_ops.placeholder( array_ops.placeholder(dtypes.float32, shape=(None, input_size))
dtypes.float32, shape=(None, input_size))
] ]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32) outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
@ -700,8 +704,7 @@ class GridRNNCellTest(test.TestCase):
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
input_value = np.ones((3, input_size)) input_value = np.ones((3, input_size))
values = sess.run(outputs + [state], values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
feed_dict={inputs[0]: input_value})
for tp in values[:-1]: for tp in values[:-1]:
for v in tp: for v in tp:
self.assertTrue(np.all(np.isfinite(v))) self.assertTrue(np.all(np.isfinite(v)))
@ -710,18 +713,15 @@ class GridRNNCellTest(test.TestCase):
for v in st: for v in st:
self.assertTrue(np.all(np.isfinite(v))) self.assertTrue(np.all(np.isfinite(v)))
def testGrid2LSTMCellLegacy(self): def testGrid2LSTMCellLegacy(self):
"""Test for legacy case (when state_is_tuple=False) """Test for legacy case (when state_is_tuple=False)."""
"""
with self.test_session() as sess: with self.test_session() as sess:
with variable_scope.variable_scope('root', with variable_scope.variable_scope(
initializer=init_ops.constant_initializer(0.5)): 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3]) x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 8]) m = array_ops.zeros([1, 8])
cell = grid_rnn_cell.Grid2LSTMCell(2, use_peepholes=True, cell = grid_rnn_cell.Grid2LSTMCell(
state_is_tuple=False, 2, use_peepholes=True, state_is_tuple=False, output_is_tuple=False)
output_is_tuple=False)
self.assertEqual(cell.state_size, 8) self.assertEqual(cell.state_size, 8)
g, s = cell(x, m) g, s = cell(x, m)
@ -729,15 +729,17 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s.get_shape(), (1, 8)) self.assertEqual(s.get_shape(), (1, 8))
sess.run([variables.global_variables_initializer()]) sess.run([variables.global_variables_initializer()])
res = sess.run( res = sess.run([g, s], {
[g, s], {x: np.array([[1., 1., 1.]]), x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])}) m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
})
self.assertEqual(res[0].shape, (1, 2)) self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8)) self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.95686918, 0.95686918]]) self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918, self.assertAllClose(res[1], [[
0.95686918, 1.38917875, 1.49043763, 2.41515064, 2.41515064, 0.95686918, 0.95686918, 1.38917875,
0.83884692, 0.86036491]]) 1.49043763, 0.83884692, 0.86036491
]])
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -102,16 +102,16 @@ class GridRNNCell(rnn.RNNCell):
output_is_tuple: If True, the output is a tuple of the outputs of the output_is_tuple: If True, the output is a tuple of the outputs of the
recurrent dimensions. If False, they are concatenated along the recurrent dimensions. If False, they are concatenated along the
column axis. The later behavior will soon be deprecated. column axis. The later behavior will soon be deprecated.
Raises: Raises:
TypeError: if cell_fn does not return an RNNCell instance. TypeError: if cell_fn does not return an RNNCell instance.
""" """
if not state_is_tuple: if not state_is_tuple:
logging.warning("%s: Using a concatenated state is slower and will " logging.warning('%s: Using a concatenated state is slower and will '
"soon be deprecated. Use state_is_tuple=True.", self) 'soon be deprecated. Use state_is_tuple=True.', self)
if not output_is_tuple: if not output_is_tuple:
logging.warning("%s: Using a concatenated output is slower and will" logging.warning('%s: Using a concatenated output is slower and will'
"soon be deprecated. Use output_is_tuple=True.", self) 'soon be deprecated. Use output_is_tuple=True.', self)
if num_dims < 1: if num_dims < 1:
raise ValueError('dims must be >= 1: {}'.format(num_dims)) raise ValueError('dims must be >= 1: {}'.format(num_dims))
@ -126,9 +126,7 @@ class GridRNNCell(rnn.RNNCell):
if cell_fn is None: if cell_fn is None:
my_cell_fn = functools.partial( my_cell_fn = functools.partial(
rnn.LSTMCell, rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
num_units=num_units,
state_is_tuple=state_is_tuple)
else: else:
my_cell_fn = lambda: cell_fn(num_units) my_cell_fn = lambda: cell_fn(num_units)
if tied: if tied:
@ -136,9 +134,8 @@ class GridRNNCell(rnn.RNNCell):
else: else:
self._cells = [my_cell_fn() for _ in range(num_dims)] self._cells = [my_cell_fn() for _ in range(num_dims)]
if not isinstance(self._cells[0], rnn.RNNCell): if not isinstance(self._cells[0], rnn.RNNCell):
raise TypeError( raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
'cell_fn must return an RNNCell instance, saw: %s' type(self._cells[0]))
% type(self._cells[0]))
if self._output_is_tuple: if self._output_is_tuple:
self._output_size = tuple(self._cells[0].output_size self._output_size = tuple(self._cells[0].output_size
@ -201,26 +198,36 @@ class GridRNNCell(rnn.RNNCell):
if self._output_is_tuple: if self._output_is_tuple:
output = tuple(output_tensors) output = tuple(output_tensors)
else: else:
if len(output_tensors) == 0: if output_tensors:
output = array_ops.zeros([0, 0], dtype)
else:
output = array_ops.concat(output_tensors, 1) output = array_ops.concat(output_tensors, 1)
else:
output = array_ops.zeros([0, 0], dtype)
if self._state_is_tuple: if self._state_is_tuple:
states = tuple(new_state[i] for i in self._config.recurrents) states = tuple(new_state[i] for i in self._config.recurrents)
else: else:
# concat each state first, then flatten the whole thing # concat each state first, then flatten the whole thing
state_tensors = [x for i in self._config.recurrents state_tensors = [
for x in new_state[i]] x for i in self._config.recurrents for x in new_state[i]
if len(state_tensors) == 0: ]
states = array_ops.zeros([0, 0], dtype) if state_tensors:
else:
states = array_ops.concat(state_tensors, 1) states = array_ops.concat(state_tensors, 1)
else:
states = array_ops.zeros([0, 0], dtype)
return output, states return output, states
def _extract_states(self, state): def _extract_states(self, state):
"""Extract the cell and previous output tensors from the given state """Extract the cell and previous output tensors from the given state.
Args:
state: The RNN state.
Returns:
Tuple of the cell value, previous output, and cell_output_size.
Raises:
ValueError: If len(self._config.recurrents) != len(state).
""" """
conf = self._config conf = self._config
@ -238,8 +245,8 @@ class GridRNNCell(rnn.RNNCell):
if self._state_is_tuple: if self._state_is_tuple:
if len(conf.recurrents) != len(state): if len(conf.recurrents) != len(state):
raise ValueError("Expected state as a tuple of {} " raise ValueError('Expected state as a tuple of {} '
"element".format(len(conf.recurrents))) 'element'.format(len(conf.recurrents)))
for recurrent_dim, recurrent_state in zip(conf.recurrents, state): for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
if cell_output_size > 0: if cell_output_size > 0:
@ -247,49 +254,62 @@ class GridRNNCell(rnn.RNNCell):
else: else:
m_prev[recurrent_dim] = recurrent_state m_prev[recurrent_dim] = recurrent_state
else: else:
for recurrent_dim, start_idx in zip(conf.recurrents, range( for recurrent_dim, start_idx in zip(conf.recurrents,
0, self.state_size, total_cell_state_size)): range(0, self.state_size,
total_cell_state_size)):
if cell_output_size > 0: if cell_output_size > 0:
c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
[-1, conf.num_units]) [-1, conf.num_units])
m_prev[recurrent_dim] = array_ops.slice( m_prev[recurrent_dim] = array_ops.slice(
state, [0, start_idx + conf.num_units], [-1, cell_output_size]) state, [0, start_idx + conf.num_units], [-1, cell_output_size])
else: else:
m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
[-1, conf.num_units]) [-1, conf.num_units])
return c_prev, m_prev, cell_output_size return c_prev, m_prev, cell_output_size
def _project_input(self, inputs, c_prev, m_prev, with_c): def _project_input(self, inputs, c_prev, m_prev, with_c):
"""Fills in c_prev and m_prev with projected input, for input dimensions """Fills in c_prev and m_prev with projected input, for input dimensions.
Args:
inputs: inputs tensor
c_prev: cell value
m_prev: previous output
with_c: boolean; whether to include project_c.
Raises:
ValueError: if len(self._config.input) != len(inputs)
""" """
conf = self._config conf = self._config
if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and
and len(conf.inputs) > 0): conf.inputs):
if isinstance(inputs, tuple): if isinstance(inputs, tuple):
if len(conf.inputs) != len(inputs): if len(conf.inputs) != len(inputs):
raise ValueError("Expect inputs as a tuple of {} " raise ValueError('Expect inputs as a tuple of {} '
"tensors".format(len(conf.inputs))) 'tensors'.format(len(conf.inputs)))
input_splits = inputs input_splits = inputs
else: else:
input_splits = array_ops.split( input_splits = array_ops.split(
value=inputs, num_or_size_splits=len(conf.inputs), axis=1) value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
input_sz = input_splits[0].get_shape().with_rank(2)[1].value input_sz = input_splits[0].get_shape().with_rank(2)[1].value
for i, j in enumerate(conf.inputs): for i, j in enumerate(conf.inputs):
input_project_m = vs.get_variable( input_project_m = vs.get_variable(
'project_m_{}'.format(j), [input_sz, conf.num_units], 'project_m_{}'.format(j), [input_sz, conf.num_units],
dtype=inputs.dtype) dtype=inputs.dtype)
m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
if with_c: if with_c:
input_project_c = vs.get_variable( input_project_c = vs.get_variable(
'project_c_{}'.format(j), [input_sz, conf.num_units], 'project_c_{}'.format(j), [input_sz, conf.num_units],
dtype=inputs.dtype) dtype=inputs.dtype)
c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
def _cell_state_size(self): def _cell_state_size(self):
"""Total size of the state of the inner cell used in this grid """Total size of the state of the inner cell used in this grid.
Returns:
Total size of the state of the inner cell.
""" """
state_sizes = self._cells[0].state_size state_sizes = self._cells[0].state_size
if isinstance(state_sizes, tuple): if isinstance(state_sizes, tuple):
@ -306,10 +326,15 @@ class Grid1BasicRNNCell(GridRNNCell):
def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True): def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
super(Grid1BasicRNNCell, self).__init__( super(Grid1BasicRNNCell, self).__init__(
num_units=num_units, num_dims=1, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, tied=False, num_dims=1,
cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), input_dims=0,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) output_dims=0,
priority_dims=0,
tied=False,
cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
class Grid2BasicRNNCell(GridRNNCell): class Grid2BasicRNNCell(GridRNNCell):
@ -322,38 +347,56 @@ class Grid2BasicRNNCell(GridRNNCell):
specified. specified.
""" """
def __init__(self, num_units, tied=False, non_recurrent_fn=None, def __init__(self,
state_is_tuple=True, output_is_tuple=True): num_units,
tied=False,
non_recurrent_fn=None,
state_is_tuple=True,
output_is_tuple=True):
super(Grid2BasicRNNCell, self).__init__( super(Grid2BasicRNNCell, self).__init__(
num_units=num_units, num_dims=2, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, tied=tied, num_dims=2,
non_recurrent_dims=None if non_recurrent_fn is None else 0, input_dims=0,
cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), output_dims=0,
non_recurrent_fn=non_recurrent_fn, priority_dims=0,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
non_recurrent_fn=non_recurrent_fn,
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
class Grid1BasicLSTMCell(GridRNNCell): class Grid1BasicLSTMCell(GridRNNCell):
"""1D BasicLSTM cell""" """1D BasicLSTM cell."""
def __init__(self, num_units, forget_bias=1, def __init__(self,
state_is_tuple=True, output_is_tuple=True): num_units,
forget_bias=1,
state_is_tuple=True,
output_is_tuple=True):
def cell_fn(n):
return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid1BasicLSTMCell, self).__init__( super(Grid1BasicLSTMCell, self).__init__(
num_units=num_units, num_dims=1, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, tied=False, num_dims=1,
cell_fn=lambda n: rnn.BasicLSTMCell( input_dims=0,
num_units=n, forget_bias=forget_bias), output_dims=0,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) priority_dims=0,
tied=False,
cell_fn=cell_fn,
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
class Grid2BasicLSTMCell(GridRNNCell): class Grid2BasicLSTMCell(GridRNNCell):
"""2D BasicLSTM cell """2D BasicLSTM cell.
This creates a 2D cell which receives input and gives output in the first This creates a 2D cell which receives input and gives output in the first
dimension. dimension.
The first dimension can optionally be non-recurrent if `non_recurrent_fn` is The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
specified. specified.
""" """
def __init__(self, def __init__(self,
@ -363,36 +406,53 @@ class Grid2BasicLSTMCell(GridRNNCell):
forget_bias=1, forget_bias=1,
state_is_tuple=True, state_is_tuple=True,
output_is_tuple=True): output_is_tuple=True):
def cell_fn(n):
return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid2BasicLSTMCell, self).__init__( super(Grid2BasicLSTMCell, self).__init__(
num_units=num_units, num_dims=2, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, tied=tied, num_dims=2,
non_recurrent_dims=None if non_recurrent_fn is None else 0, input_dims=0,
cell_fn=lambda n: rnn.BasicLSTMCell( output_dims=0,
num_units=n, forget_bias=forget_bias), priority_dims=0,
non_recurrent_fn=non_recurrent_fn, tied=tied,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=cell_fn,
non_recurrent_fn=non_recurrent_fn,
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
class Grid1LSTMCell(GridRNNCell): class Grid1LSTMCell(GridRNNCell):
"""1D LSTM cell """1D LSTM cell.
This is different from Grid1BasicLSTMCell because it gives options to This is different from Grid1BasicLSTMCell because it gives options to
specify the forget bias and enabling peepholes specify the forget bias and enabling peepholes.
""" """
def __init__(self, num_units, use_peepholes=False, forget_bias=1.0, def __init__(self,
state_is_tuple=True, output_is_tuple=True): num_units,
use_peepholes=False,
forget_bias=1.0,
state_is_tuple=True,
output_is_tuple=True):
def cell_fn(n):
return rnn.LSTMCell(
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
super(Grid1LSTMCell, self).__init__( super(Grid1LSTMCell, self).__init__(
num_units=num_units, num_dims=1, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, num_dims=1,
cell_fn=lambda n: rnn.LSTMCell( input_dims=0,
num_units=n, use_peepholes=use_peepholes, output_dims=0,
forget_bias=forget_bias), priority_dims=0,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) cell_fn=cell_fn,
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
class Grid2LSTMCell(GridRNNCell): class Grid2LSTMCell(GridRNNCell):
"""2D LSTM cell """2D LSTM cell.
This creates a 2D cell which receives input and gives output in the first This creates a 2D cell which receives input and gives output in the first
dimension. dimension.
@ -408,19 +468,27 @@ class Grid2LSTMCell(GridRNNCell):
forget_bias=1.0, forget_bias=1.0,
state_is_tuple=True, state_is_tuple=True,
output_is_tuple=True): output_is_tuple=True):
def cell_fn(n):
return rnn.LSTMCell(
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
super(Grid2LSTMCell, self).__init__( super(Grid2LSTMCell, self).__init__(
num_units=num_units, num_dims=2, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, tied=tied, num_dims=2,
non_recurrent_dims=None if non_recurrent_fn is None else 0, input_dims=0,
cell_fn=lambda n: rnn.LSTMCell( output_dims=0,
num_units=n, forget_bias=forget_bias, priority_dims=0,
use_peepholes=use_peepholes), tied=tied,
non_recurrent_fn=non_recurrent_fn, non_recurrent_dims=None if non_recurrent_fn is None else 0,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) cell_fn=cell_fn,
non_recurrent_fn=non_recurrent_fn,
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
class Grid3LSTMCell(GridRNNCell): class Grid3LSTMCell(GridRNNCell):
"""3D BasicLSTM cell """3D BasicLSTM cell.
This creates a 2D cell which receives input and gives output in the first This creates a 2D cell which receives input and gives output in the first
dimension. dimension.
@ -437,19 +505,27 @@ class Grid3LSTMCell(GridRNNCell):
forget_bias=1.0, forget_bias=1.0,
state_is_tuple=True, state_is_tuple=True,
output_is_tuple=True): output_is_tuple=True):
def cell_fn(n):
return rnn.LSTMCell(
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
super(Grid3LSTMCell, self).__init__( super(Grid3LSTMCell, self).__init__(
num_units=num_units, num_dims=3, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, tied=tied, num_dims=3,
non_recurrent_dims=None if non_recurrent_fn is None else 0, input_dims=0,
cell_fn=lambda n: rnn.LSTMCell( output_dims=0,
num_units=n, forget_bias=forget_bias, priority_dims=0,
use_peepholes=use_peepholes), tied=tied,
non_recurrent_fn=non_recurrent_fn, non_recurrent_dims=None if non_recurrent_fn is None else 0,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) cell_fn=cell_fn,
non_recurrent_fn=non_recurrent_fn,
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
class Grid2GRUCell(GridRNNCell): class Grid2GRUCell(GridRNNCell):
"""2D LSTM cell """2D LSTM cell.
This creates a 2D cell which receives input and gives output in the first This creates a 2D cell which receives input and gives output in the first
dimension. dimension.
@ -457,23 +533,31 @@ class Grid2GRUCell(GridRNNCell):
specified. specified.
""" """
def __init__(self, num_units, tied=False, non_recurrent_fn=None, def __init__(self,
state_is_tuple=True, output_is_tuple=True): num_units,
tied=False,
non_recurrent_fn=None,
state_is_tuple=True,
output_is_tuple=True):
super(Grid2GRUCell, self).__init__( super(Grid2GRUCell, self).__init__(
num_units=num_units, num_dims=2, num_units=num_units,
input_dims=0, output_dims=0, priority_dims=0, tied=tied, num_dims=2,
non_recurrent_dims=None if non_recurrent_fn is None else 0, input_dims=0,
cell_fn=lambda n: rnn.GRUCell(num_units=n), output_dims=0,
non_recurrent_fn=non_recurrent_fn, priority_dims=0,
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple) tied=tied,
non_recurrent_dims=None if non_recurrent_fn is None else 0,
cell_fn=lambda n: rnn.GRUCell(num_units=n),
non_recurrent_fn=non_recurrent_fn,
state_is_tuple=state_is_tuple,
output_is_tuple=output_is_tuple)
"""Helpers # Helpers
"""
_GridRNNDimension = namedtuple( _GridRNNDimension = namedtuple('_GridRNNDimension', [
'_GridRNNDimension', 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn']) ])
_GridRNNConfig = namedtuple('_GridRNNConfig', _GridRNNConfig = namedtuple('_GridRNNConfig',
['num_dims', 'dims', 'inputs', 'outputs', ['num_dims', 'dims', 'inputs', 'outputs',
@ -502,23 +586,23 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
rnn_dims = [] rnn_dims = []
for i in range(num_dims): for i in range(num_dims):
rnn_dims.append( rnn_dims.append(
_GridRNNDimension( _GridRNNDimension(
idx=i, idx=i,
is_input=(i in input_dims), is_input=(i in input_dims),
is_output=(i in output_dims), is_output=(i in output_dims),
is_priority=(i in priority_dims), is_priority=(i in priority_dims),
non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else non_recurrent_fn=non_recurrent_fn
None)) if i in non_recurrent_dims else None))
return _GridRNNConfig( return _GridRNNConfig(
num_dims=num_dims, num_dims=num_dims,
dims=rnn_dims, dims=rnn_dims,
inputs=input_dims, inputs=input_dims,
outputs=output_dims, outputs=output_dims,
recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims], recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
priority=priority_dims, priority=priority_dims,
non_priority=[x for x in range(num_dims) if x not in priority_dims], non_priority=[x for x in range(num_dims) if x not in priority_dims],
tied=tied, tied=tied,
num_units=num_units) num_units=num_units)
def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state, def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
@ -544,8 +628,8 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
m_prev[0].dtype) m_prev[0].dtype)
last_dim_output = (new_output[-1] if new_output[-1] is not None last_dim_output = (new_output[-1]
else m_prev[-1]) if new_output[-1] is not None else m_prev[-1])
for i in dim_indices: for i in dim_indices:
d = conf.dims[i] d = conf.dims[i]
@ -560,12 +644,12 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
vs.get_variable_scope().reuse_variables() vs.get_variable_scope().reuse_variables()
new_output[d.idx] = layers.fully_connected( new_output[d.idx] = layers.fully_connected(
linear_args, linear_args,
num_outputs=conf.num_units, num_outputs=conf.num_units,
activation_fn=d.non_recurrent_fn, activation_fn=d.non_recurrent_fn,
weights_initializer=vs.get_variable_scope().initializer or weights_initializer=(vs.get_variable_scope().initializer or
layers.initializers.xavier_initializer, layers.initializers.xavier_initializer),
weights_regularizer=vs.get_variable_scope().regularizer) weights_regularizer=vs.get_variable_scope().regularizer)
else: else:
if c_prev[i] is not None: if c_prev[i] is not None:
cell_state = (c_prev[i], last_dim_output) cell_state = (c_prev[i], last_dim_output)

View File

@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
using functor::FillProjectiveTransform; using functor::FillProjectiveTransform;
using generator::INTERPOLATION_BILINEAR;
using generator::INTERPOLATION_NEAREST;
using generator::Interpolation;
using generator::ProjectiveGenerator; using generator::ProjectiveGenerator;
template <typename Device, typename T> template <typename Device, typename T>
class ImageProjectiveTransform : public OpKernel { class ImageProjectiveTransform : public OpKernel {
private:
Interpolation interpolation_;
public: public:
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
: OpKernel(ctx) {} string interpolation_str;
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
if (interpolation_str == "NEAREST") {
interpolation_ = INTERPOLATION_NEAREST;
} else if (interpolation_str == "BILINEAR") {
interpolation_ = INTERPOLATION_BILINEAR;
} else {
LOG(FATAL) << "Invalid interpolation " << interpolation_str
<< ". Supported types: NEAREST, BILINEAR";
}
}
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0); const Tensor& images_t = ctx->input(0);
@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
Tensor* output_t; Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
auto output = output_t->tensor<T, 4>(); auto output = output_t->tensor<T, 4>();
const FillProjectiveTransform<Device, T> functor; (FillProjectiveTransform<Device, T>(interpolation_))(
functor(ctx->eigen_device<Device>(), &output, images, transform); ctx->eigen_device<Device>(), &output, images, transform);
} }
}; };

View File

@ -28,6 +28,8 @@ namespace tensorflow {
namespace generator { namespace generator {
enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
using Eigen::array; using Eigen::array;
using Eigen::DenseIndex; using Eigen::DenseIndex;
@ -36,20 +38,19 @@ class ProjectiveGenerator {
private: private:
typename TTypes<T, 4>::ConstTensor input_; typename TTypes<T, 4>::ConstTensor input_;
typename TTypes<float>::ConstMatrix transforms_; typename TTypes<float>::ConstMatrix transforms_;
const Interpolation interpolation_;
public: public:
static const int kNumParameters = 8; static const int kNumParameters = 8;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input, ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
typename TTypes<float>::ConstMatrix transforms) typename TTypes<float>::ConstMatrix transforms,
: input_(input), transforms_(transforms) {} const Interpolation interpolation)
: input_(input), transforms_(transforms), interpolation_(interpolation) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const array<DenseIndex, 4>& coords) const { operator()(const array<DenseIndex, 4>& coords) const {
array<DenseIndex, 4> input_coords;
input_coords[0] = coords[0];
const int64 output_y = coords[1]; const int64 output_y = coords[1];
const int64 output_x = coords[2]; const int64 output_x = coords[2];
const float* transform = const float* transform =
@ -57,24 +58,73 @@ class ProjectiveGenerator {
? transforms_.data() ? transforms_.data()
: &transforms_.data()[transforms_.dimension(1) * coords[0]]; : &transforms_.data()[transforms_.dimension(1) * coords[0]];
float projection = transform[6] * output_x + transform[7] * output_y + 1.f; float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
const int64 input_x = std::round( const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) / (transform[0] * output_x + transform[1] * output_y + transform[2]) /
projection); projection;
const int64 input_y = std::round( const float input_y =
(transform[3] * output_x + transform[4] * output_y + transform[5]) / (transform[3] * output_x + transform[4] * output_y + transform[5]) /
projection); projection;
if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x && // TODO(ringwalt): Add a fill value input.
input_x < input_.dimension(2))) { static const T fill_value = T(0);
// TODO(ringwalt): Add a fill value input. switch (interpolation_) {
return T(0); case INTERPOLATION_NEAREST:
// Switch the order of x and y again for indexing into the image.
return nearest_interpolation(coords[0], input_y, input_x, coords[3],
fill_value);
case INTERPOLATION_BILINEAR:
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
fill_value);
} }
input_coords[1] = input_y; }
input_coords[2] = input_x;
input_coords[3] = coords[3]; private:
// Samples the input at the pixel nearest to the (possibly fractional) source
// coordinate (y, x) for the given batch and channel. Returns `fill_value` when
// that pixel lies outside the image.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
nearest_interpolation(const DenseIndex batch, const float y, const float x,
                      const DenseIndex channel, const T fill_value) const {
  const float rounded_y = std::round(y);
  const float rounded_x = std::round(x);
  // The coordinates come from a division by the projective term, so they can
  // be NaN, infinite or far larger than any index. Converting such a float to
  // an integer is undefined behaviour, so reject them while still in floating
  // point. The upper bound is deliberately inclusive: it only has to guarantee
  // that the value is convertible; read_with_fill_value() performs the exact
  // integer bounds check. Negated form so that NaN (for which every comparison
  // is false) also takes the fill path.
  if (!(rounded_y >= 0.f &&
        rounded_y <= static_cast<float>(input_.dimension(1)) &&
        rounded_x >= 0.f &&
        rounded_x <= static_cast<float>(input_.dimension(2)))) {
    return fill_value;
  }
  return read_with_fill_value(batch, static_cast<DenseIndex>(rounded_y),
                              static_cast<DenseIndex>(rounded_x), channel,
                              fill_value);
}
return input_(input_coords); EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
bilinear_interpolation(const DenseIndex batch, const float y, const float x,
const DenseIndex channel, const T fill_value) const {
const float y_floor = std::floor(y);
const float x_floor = std::floor(x);
const float y_ceil = y_floor + 1;
const float x_ceil = x_floor + 1;
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
const float value_yfloor =
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
DenseIndex(x_floor), channel,
fill_value) +
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
DenseIndex(x_ceil), channel,
fill_value);
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
const float value_yceil =
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
DenseIndex(x_floor), channel,
fill_value) +
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
DenseIndex(x_ceil), channel,
fill_value);
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
const DenseIndex batch, const DenseIndex y, const DenseIndex x,
const DenseIndex channel, const T fill_value) const {
// batch and channel must be correct, because they are passed unchanged from
// the input.
return (0 <= y && y < input_.dimension(1) && 0 <= x &&
x < input_.dimension(2))
? input_(array<DenseIndex, 4>{batch, y, x, channel})
: fill_value;
} }
}; };
@ -85,6 +135,7 @@ class ProjectiveGenerator {
// some Eigen device code. // some Eigen device code.
namespace functor { namespace functor {
using generator::Interpolation;
using generator::ProjectiveGenerator; using generator::ProjectiveGenerator;
template <typename Device, typename T> template <typename Device, typename T>
@ -92,15 +143,17 @@ struct FillProjectiveTransform {
typedef typename TTypes<T, 4>::Tensor OutputType; typedef typename TTypes<T, 4>::Tensor OutputType;
typedef typename TTypes<T, 4>::ConstTensor InputType; typedef typename TTypes<T, 4>::ConstTensor InputType;
typedef typename TTypes<float, 2>::ConstTensor TransformsType; typedef typename TTypes<float, 2>::ConstTensor TransformsType;
const Interpolation interpolation_;
FillProjectiveTransform() {} FillProjectiveTransform(Interpolation interpolation)
: interpolation_(interpolation) {}
EIGEN_ALWAYS_INLINE EIGEN_ALWAYS_INLINE
void operator()(const Device& device, OutputType* output, void operator()(const Device& device, OutputType* output,
const InputType& images, const InputType& images,
const TransformsType& transform) const { const TransformsType& transform) const {
ProjectiveGenerator<Device, T> generator(images, transform); output->device(device) = images.generate(
output->device(device) = images.generate(generator); ProjectiveGenerator<Device, T>(images, transform, interpolation_));
} }
}; };

View File

@ -23,13 +23,13 @@ using shape_inference::InferenceContext;
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc. // TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0). // TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc.
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to // TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
// implement "same" and "valid" modes in the Python function. // implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform") REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype") .Input("images: dtype")
.Input("transforms: float32") .Input("transforms: float32")
.Attr("dtype: {uint8, int32, int64, float32, float64}") .Attr("dtype: {uint8, int32, int64, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype") .Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) { .SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(0)); c->set_output(0, c->input(0));

View File

@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
@ -111,6 +112,79 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0, 1, 0, 1], [0, 1, 0, 1],
[0, 1, 1, 1]]) [0, 1, 1, 1]])
def test_bilinear(self):
  """Rotates a 5x5 ring by 45 degrees with both interpolation modes."""
  ring = [[0, 0, 0, 0, 0],
          [0, 1, 1, 1, 0],
          [0, 1, 0, 1, 0],
          [0, 1, 1, 1, 0],
          [0, 0, 0, 0, 0]]
  # The bilinear reference matches:
  # >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
  # which uses spline interpolation of order 1, equivalent to bilinear
  # interpolation.
  expected_bilinear = [[0.000, 0.000, 0.343, 0.000, 0.000],
                       [0.000, 0.586, 0.914, 0.586, 0.000],
                       [0.343, 0.914, 0.000, 0.914, 0.343],
                       [0.000, 0.586, 0.914, 0.586, 0.000],
                       [0.000, 0.000, 0.343, 0.000, 0.000]]
  expected_nearest = [[0, 0, 1, 0, 0],
                      [0, 1, 1, 1, 0],
                      [1, 1, 0, 1, 1],
                      [0, 1, 1, 1, 0],
                      [0, 0, 1, 0, 0]]
  with self.test_session():
    image = constant_op.constant(ring, dtypes.float32)
    rotated_bilinear = image_ops.rotate(
        image, np.pi / 4.0, interpolation="BILINEAR")
    self.assertAllClose(
        rotated_bilinear.eval(), expected_bilinear, atol=0.001)
    rotated_nearest = image_ops.rotate(
        image, np.pi / 4.0, interpolation="NEAREST")
    self.assertAllClose(rotated_nearest.eval(), expected_nearest)
def test_bilinear_uint8(self):
  """Bilinear rotation of a uint8 ring yields the expected integer pixels."""
  pixels = np.asarray(
      [[0, 0, 0, 0, 0],
       [0, 255, 255, 255, 0],
       [0, 255, 0, 255, 0],
       [0, 255, 255, 255, 0],
       [0, 0, 0, 0, 0]],
      np.uint8)
  # == np.rint((float expectation for the 0/1 ring) * 255)
  expected = [[0, 0, 87, 0, 0],
              [0, 149, 233, 149, 0],
              [87, 233, 0, 233, 87],
              [0, 149, 233, 149, 0],
              [0, 0, 87, 0, 0]]
  with self.test_session():
    image = constant_op.constant(pixels, dtypes.uint8)
    rotated = image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR")
    self.assertAllEqual(rotated.eval(), expected)
def _test_grad(self, shape_to_test):
  """Checks the gradient of `transform` numerically for one image shape.

  Args:
    shape_to_test: Shape (list of ints) of the random image to build; the
      output of the transform is checked against the same shape.
  """
  with self.test_session():
    random_image = np.random.randn(*shape_to_test)
    image_tensor = constant_op.constant(random_image, shape=shape_to_test)
    quarter_turn = image_ops.angles_to_projective_transforms(
        np.pi / 2, 4, 4)
    transformed = image_ops.transform(image_tensor, quarter_turn)
    gradient_error = gradient_checker.compute_gradient_error(
        image_tensor,
        shape_to_test,
        transformed,
        shape_to_test,
        x_init_value=random_image)
    self.assertLess(gradient_error, 1e-10)
def test_grad(self):
self._test_grad([16, 16])
self._test_grad([4, 12, 12])
self._test_grad([3, 4, 12, 12])
def _test_grad(self, shape_to_test): def _test_grad(self, shape_to_test):
with self.test_session(): with self.test_session():

View File

@ -24,8 +24,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader from tensorflow.python.platform import resource_loader
_image_ops_so = loader.load_op_library( _image_ops_so = loader.load_op_library(
@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
def rotate(images, angles): def rotate(images, angles, interpolation="NEAREST"):
"""Rotate image(s) by the passed angle(s) in radians. """Rotate image(s) by the passed angle(s) in radians.
Args: Args:
@ -46,6 +46,7 @@ def rotate(images, angles):
(num_rows, num_columns) (HW). (num_rows, num_columns) (HW).
angles: A scalar angle to rotate all images by, or (if images has rank 4) angles: A scalar angle to rotate all images by, or (if images has rank 4)
a vector of length num_images, with an angle for each image in the batch. a vector of length num_images, with an angle for each image in the batch.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns: Returns:
Image(s) with the same type and shape as `images`, rotated by the given Image(s) with the same type and shape as `images`, rotated by the given
@ -70,7 +71,8 @@ def rotate(images, angles):
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None] image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
output = transform( output = transform(
images, images,
angles_to_projective_transforms(angles, image_width, image_height)) angles_to_projective_transforms(angles, image_height, image_width),
interpolation=interpolation)
if len(image_or_images.get_shape()) == 2: if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0] return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3: elif len(image_or_images.get_shape()) == 3:
@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
axis=1) axis=1)
def transform(images, transforms): def transform(images, transforms, interpolation="NEAREST"):
"""Applies the given transform(s) to the image(s). """Applies the given transform(s) to the image(s).
Args: Args:
@ -134,6 +136,7 @@ def transform(images, transforms):
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points. the transform mapping input points to output points.
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns: Returns:
Image(s) with the same type and shape as `images`, with the given Image(s) with the same type and shape as `images`, with the given
@ -163,8 +166,8 @@ def transform(images, transforms):
transforms = transform_or_transforms transforms = transform_or_transforms
else: else:
raise TypeError("Transforms should have rank 1 or 2.") raise TypeError("Transforms should have rank 1 or 2.")
# pylint: disable=protected-access output = gen_image_ops.image_projective_transform(
output = gen_image_ops.image_projective_transform(images, transforms) images, transforms, interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2: if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0] return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3: elif len(image_or_images.get_shape()) == 3:
@ -217,8 +220,10 @@ def _transform_matrices_to_flat(transform_matrices):
@ops.RegisterGradient("ImageProjectiveTransform") @ops.RegisterGradient("ImageProjectiveTransform")
def _image_projective_transform_grad(op, grad): def _image_projective_transform_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0] images = op.inputs[0]
transforms = op.inputs[1] transforms = op.inputs[1]
interpolation = op.get_attr("interpolation")
image_or_images = ops.convert_to_tensor(images, name="images") image_or_images = ops.convert_to_tensor(images, name="images")
transform_or_transforms = ops.convert_to_tensor( transform_or_transforms = ops.convert_to_tensor(
@ -245,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
transforms = _flat_transforms_to_matrices(transforms=transforms) transforms = _flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms) inverse = linalg_ops.matrix_inverse(transforms)
transforms = _transform_matrices_to_flat(inverse) transforms = _transform_matrices_to_flat(inverse)
output = gen_image_ops.image_projective_transform(grad, transforms) output = gen_image_ops.image_projective_transform(
grad, transforms, interpolation=interpolation)
if len(image_or_images.get_shape()) == 2: if len(image_or_images.get_shape()) == 2:
return [output[0, :, :, 0], None] return [output[0, :, :, 0], None]
elif len(image_or_images.get_shape()) == 3: elif len(image_or_images.get_shape()) == 3:

View File

@ -0,0 +1,55 @@
# TensorFlow contrib kernel_methods.
This module contains operations and estimators that enable the use of primal
(explicit) kernel methods in TensorFlow. See also the [tutorial](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/g3doc/tutorial.md) on how to use this module to improve the quality of
classification or regression tasks.
## Kernel Mappers
Implement explicit kernel mapping Ops over tensors. Kernel mappers add
Tensor-In-Tensor-Out (TITO) Ops to the TensorFlow graph. They can be used in
conjunction with other layers or ML models.
Sample usage:
```python
kernel_mapper = tf.contrib.kernel_methods.SomeKernelMapper(...)
out_tensor = kernel_mapper.map(in_tensor)
... # code that consumes out_tensor.
```
Currently, there is a
[RandomFourierFeatureMapper](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py)
implemented that maps dense input to dense output.
## Kernel-based Estimators
tf.contrib.learn Estimators that use kernel mappers internally to discover
non-linearities in the data. These canned estimators map their input features
using kernel mapper Ops and then apply linear models to the mapped
features. Combining kernel mappers with linear models and different loss
functions leads to a variety of models: linear and non-linear SVMs, linear
regression (with and without kernels) and (multinomial) logistic regression
(with and without kernels).
Currently, there is a
[KernelLinearClassifier](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/kernel_estimators.py)
implemented, but more pre-packaged estimators are on the way.
Sample usage:
```python
real_column_a = tf.contrib.layers.real_valued_column(name='real_column_a',...)
sparse_column_b = tf.contrib.layers.sparse_column_with_hash_bucket(...)
kernel_mappers = {real_column_a : [tf.contrib.kernel_methods.SomeKernelMapper(...)]}
optimizer = ...
kernel_classifier = tf.contrib.kernel_methods.KernelLinearClassifier(
feature_columns=[real_column_a, sparse_column_b],
model_dir=...,
optimizer=optimizer,
kernel_mappers=kernel_mappers)
# Construct input_fns
kernel_classifier.fit(...)
kernel_classifier.evaluate(...)
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

View File

@ -0,0 +1,273 @@
# Improving classification using explicit kernel methods
In this tutorial, we demonstrate how combining (explicit) kernel methods with
linear models can drastically increase the latter's quality of predictions
without significantly increasing training and inference times. Currently,
explicit kernel mappings are supported for dense features. Support for sparse
features is in the works.
We will use [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn) (TensorFlow's high-level Machine Learning API) Estimators for our ML models.
tf.contrib.learn API reduces the boilerplate code one needs to write for
configuring, training and evaluating models and will let us focus on the core
ideas. If you are not familiar with this API, [tf.contrib.learn Quickstart](https://www.tensorflow.org/get_started/tflearn) is a good place to start. We
will use MNIST, a widely-used dataset containing images of handwritten digits
(between 0 and 9). The tutorial consists of the following steps:
* Load and prepare MNIST data for classification.
* Construct a simple linear model, train it and evaluate it on the eval data.
* Replace the linear model with a kernelized linear model, re-train and
re-evaluate.
## Load and prepare MNIST data for classification
The first step is to prepare the data to be fed to the ML models. The following
utility command from tf.contrib.learn loads the MNIST dataset:
```python
data = tf.contrib.learn.datasets.mnist.load_mnist()
```
This loads the entire MNIST dataset (containing 70K samples) and splits it into
train, validation and test data with 55K, 5K and 10K samples respectively. Each
split contains one numpy array for images (with shape [sample_size, 784]) and
one for labels (with shape [sample_size, 1]). In this tutorial, we only use the
train and validation splits (to train and evaluate our models respectively).
In order to feed data to a tf.contrib.learn Estimator, it is helpful to convert
it to Tensors. For this, we will use an `input function` which adds Ops to the
TensorFlow graph that, when executed, create mini-batches of Tensors to be used
downstream. For more background on input functions, check
[Building Input Functions with tf.contrib.learn](https://www.tensorflow.org/get_started/input_fn). In this example, we will use the `tf.train.shuffle_batch` Op which,
besides converting numpy arrays to Tensors, allows us to specify the batch_size
and whether to randomize the input every time the input_fn Ops are executed
(randomization typically expedites convergence during training). The full code
for loading and preparing the data is shown in the snippet below. In this
example, we use mini-batches of size 256 for training and the entire sample (5K
entries) for evaluation. Feel free to experiment with different batch sizes.
```python
import numpy as np
import tensorflow as tf
def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):
def _input_fn():
images_batch, labels_batch = tf.train.shuffle_batch(
tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue,
enqueue_many=True,
num_threads=4)
features_map = {'images': images_batch}
return features_map, labels_batch
return _input_fn
data = tf.contrib.learn.datasets.mnist.load_mnist()
train_input_fn = get_input_fn(data.train, batch_size=256)
eval_input_fn = get_input_fn(data.validation, batch_size=5000)
```
## Training a simple linear model
We can now train a linear model over the MNIST dataset. We will use the [tf.contrib.learn.LinearClassifier](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py) estimator with 10 classes (representing the 10
digits). The input features form a 784-dimensional (dense) vector which can be
specified as follows:
```python
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
```
The full code for constructing, training and evaluating a LinearClassifier
estimator is shown below.
```python
import time
# Specify the feature(s) to be used by the estimator.
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
estimator = tf.contrib.learn.LinearClassifier(feature_columns=[image_column], n_classes=10)
# Train.
start = time.time()
estimator.fit(input_fn=train_input_fn, steps=2000)
end = time.time()
print('Elapsed time: {} seconds'.format(end - start))
# Evaluate and report metrics.
eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
print(eval_metrics)
```
On eval data, the loss (i.e., the value of the objective function being
minimized during training) lies between **0.25 and 0.30** (depending on the
parameters used) while the accuracy of the classifier is approximately **92.5%**
(training is randomized so the exact loss and accuracy will vary). Also, the
training time is around 25 seconds (this will also vary based on the machine you
run the code on).
In addition to experimenting with the (training) batch size and the number of
training steps, there are a couple other parameters that can be tuned as well.
For instance, you can change the optimization method used to minimize the loss
by explicitly selecting another optimizer from the collection of
[available optimizers](https://www.tensorflow.org/code/tensorflow/python/training).
As an example, the following code constructs a LinearClassifier estimator that
uses the Follow-The-Regularized-Leader (FTRL) optimization strategy with a
specific learning rate and l2-regularization.
```python
optimizer = tf.train.FtrlOptimizer(learning_rate=5.0, l2_regularization_strength=1.0)
estimator = tf.contrib.learn.LinearClassifier(
feature_columns=[image_column], n_classes=10, optimizer=optimizer)
```
Regardless of the values of the parameters, the max accuracy a linear model can
achieve on this dataset caps at around **93%**.
## Using explicit kernel mappings with the linear model.
The relatively high error (~7%) of the linear model over MNIST indicates that
the input data is not linearly separable. We will use explicit kernel mappings
to reduce the classification error.
**Intuition:** The high-level idea is to use a non-linear map to transform the
input space to another feature space (of possibly higher dimension) where the
(transformed) features are (almost) linearly separable and then apply a linear
model on the mapped features. This is shown in the following figure:
<div style="text-align:center">
<img src="./kernel_mapping.png">
</div>
**Technical details overview:** In this example we will use **Random Fourier
Features** (introduced in the
["Random Features for Large-Scale Kernel Machines"](https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf)
paper by
Rahimi and Recht) to map the input data. Random Fourier Features map a vector
$$\mathbf{x} \in \mathbb{R}^d$$ to $$\mathbf{x'} \in \mathbb{R}^D$$ via the
following mapping:
$$
RFFM(\cdot): \mathbb{R}^d \to \mathbb{R}^D, \quad
RFFM(\mathbf{x}) = \cos(\mathbf{\Omega} \cdot \mathbf{x}+ \mathbf{b})
$$
where $$\mathbf{\Omega} \in \mathbb{R}^{D \times d}$$,
$$\mathbf{x} \in \mathbb{R}^d,$$ $$\mathbf{b} \in \mathbb{R}^D$$ and cosine is
applied elementwise.
In this example, the entries of $$\mathbf{\Omega}$$ and $$\mathbf{b}$$ are
sampled from distributions such that the mapping satisfies the following
property:
$$
RFFM(\mathbf{x})^T \cdot RFFM(\mathbf{y}) \approx
e^{-\frac{\|\mathbf{x} - \mathbf{y}\|^2}{2 \sigma^2}}
$$
The right-hand-side quantity of the expression above is known as the RBF (or
Gaussian) kernel function. This function is one of the most-widely used kernel
functions in Machine Learning and measures (implicitly) similarity in a
different (much higher dimensional) space than the original one. See
[Radial basis function kernel](https://en.wikipedia.org/wiki/Radial_basis_function_kernel)
for more details.
**Kernel Classifier:** `tf.contrib.kernel_methods.KernelLinearClassifier` is a
pre-packaged `tf.contrib.learn` estimator that combines the power of explicit
kernel mappings with linear models. Its API is very similar to that of the
LinearClassifier with the additional ability to specify a list of explicit
kernel mappings to be applied to each feature used by the classifier. The
following code snippet demonstrates how to replace LinearClassifier with
KernelLinearClassifier.
```python
# Specify the feature(s) to be used by the estimator. This is identical to the
# code used for the LinearClassifier.
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
optimizer = tf.train.FtrlOptimizer(
learning_rate=50.0, l2_regularization_strength=0.001)
kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
kernel_mappers = {image_column: [kernel_mapper]}
estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
# Train.
start = time.time()
estimator.fit(input_fn=train_input_fn, steps=2000)
end = time.time()
print('Elapsed time: {} seconds'.format(end - start))
# Evaluate and report metrics.
eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
print(eval_metrics)
```
The only additional parameter passed to `KernelLinearClassifier` is a dictionary
from feature_columns to a list of kernel mappings to be applied to the
corresponding feature column. In this example, the lines
```python
kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
kernel_mappers = {image_column: [kernel_mapper]}
estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
```
instruct the classifier to first map the initial 784-dimensional images to
2000-dimensional vectors using random Fourier features and then learn a linear
model on the transformed vectors. Note that, besides the output dimension, there
is one more parameter (stddev) involved. This parameter is the standard
deviation ($$\sigma$$) of the approximated RBF kernel and controls the
similarity measure used in classification. This parameter is typically
determined via hyperparameter tuning.
Running the code above yields a loss of approximately **0.10** while the
accuracy is increased to approximately **97%** on eval data (an increase of 4%
over the plain linear model). The training time hovers around 35 seconds. We can
increase the accuracy even more, by increasing the output dimension of the
mapping and tuning the standard deviation even more.
**On the role of stddev:** The classification quality is very sensitive to the
value of the stddev parameter used to define the similarity measure between the
pairs of input features. The following table shows the accuracy of the
classifier on the eval data for different values of stddev (for all experiments
the output dimension was fixed to 3000). The optimal value is stddev=5.0. Notice
how too small or too high stddev values can dramatically decrease the accuracy
of the classification.
stddev | eval accuracy
:----- | :------------
1.0 | 0.1362
2.0 | 0.4764
4.0 | 0.9654
5.0 | 0.9766
8.0 | 0.9714
16.0 | 0.8878
**On the role of the output dimension:** Intuitively, the larger the output
dimension of the mapping, the closer the inner product of two mapped vectors
approximates the kernel which typically translates to better classification
accuracy. Another way to think about this is that the output dimension equals
the number of weights of the linear model (the larger this dimension, the larger
the "degrees of freedom" of the model). However, after a certain threshold,
higher output dimensions increase the accuracy by very little (while still
increasing the training time). This is shown in the following 2 Figures which
depict the eval accuracy as a function of the output dimension and the training
time respectively.
![image](./acc_vs_outdim.png) ![image](./acc-vs-trn_time.png)
## Explicit Kernel Mappings: summary and practical tips
* Explicit kernel mappings combine the predictive power of non-linear models
with the scalability of linear models.
* Random Fourier Features can be particularly effective for datasets with dense
features.
* The parameters of the kernel mapping are often data-dependent. Model quality
can be very sensitive to these parameters. Use hyperparameter tuning to find the
optimal values.
* If you have multiple numerical features, concatenate them into a single
multi-dimensional one and apply the kernel mapping to the concatenated vector.

View File

@ -121,7 +121,7 @@ def embed_sequence(ids,
`Tensor` of `[batch_size, doc_length, embed_dim]` with embedded sequences. `Tensor` of `[batch_size, doc_length, embed_dim]` with embedded sequences.
Raises: Raises:
ValueError: if `embed_dim` or `vocab_size` are not specified when ValueError: if `embed_dim` or `vocab_size` are not specified when
`reuse` is `None` or `False`. `reuse` is `None` or `False`.
""" """
if not (reuse or (vocab_size and embed_dim)): if not (reuse or (vocab_size and embed_dim)):

View File

@ -131,21 +131,27 @@ import math
import six import six
from tensorflow.contrib import lookup from tensorflow.contrib import lookup
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
@ -291,11 +297,13 @@ class _FeatureColumn(object):
# TODO(b/30410315): Support warm starting in all feature columns. # TODO(b/30410315): Support warm starting in all feature columns.
class _SparseColumn(_FeatureColumn, class _SparseColumn(
collections.namedtuple("_SparseColumn", _FeatureColumn,
["column_name", "is_integerized", fc_core._CategoricalColumn, # pylint: disable=protected-access
"bucket_size", "lookup_config", collections.namedtuple("_SparseColumn", [
"combiner", "dtype"])): "column_name", "is_integerized", "bucket_size", "lookup_config",
"combiner", "dtype"
])):
"""Represents a sparse feature column also known as categorical features. """Represents a sparse feature column also known as categorical features.
Instances of this class are immutable. A sparse column means features are Instances of this class are immutable. A sparse column means features are
@ -426,9 +434,8 @@ class _SparseColumn(_FeatureColumn,
initializer=init_ops.zeros_initializer(), initializer=init_ops.zeros_initializer(),
combiner=self.combiner) combiner=self.combiner)
def _get_input_sparse_tensor(self, columns_to_tensors): def _get_input_sparse_tensor(self, input_tensor):
"""Looks up the input tensor for transformation and sparsify it if dense.""" """sparsify input_tensor if dense."""
input_tensor = columns_to_tensors[self.name]
if not isinstance(input_tensor, sparse_tensor_py.SparseTensor): if not isinstance(input_tensor, sparse_tensor_py.SparseTensor):
# To avoid making any assumptions about which values are to be ignored, # To avoid making any assumptions about which values are to be ignored,
# we set ignore_value to -1 for numeric tensors to avoid excluding valid # we set ignore_value to -1 for numeric tensors to avoid excluding valid
@ -455,18 +462,44 @@ class _SparseColumn(_FeatureColumn,
format(self.name, other_column.name)) format(self.name, other_column.name))
return compatible return compatible
@abc.abstractmethod
def _do_transform(self, input_tensor):
pass
def insert_transformed_feature(self, columns_to_tensors):
  """Handles sparse column to id conversion.

  Reads the raw tensor stored under `self.name`, sparsifies it with
  `_get_input_sparse_tensor`, runs `_do_transform` on it and stores the result
  back into `columns_to_tensors` keyed by the column object itself.

  Args:
    columns_to_tensors: mutable mapping holding the raw input under
      `self.name`; it is updated in place.
  """
  raw_tensor = columns_to_tensors[self.name]
  sparse_input = self._get_input_sparse_tensor(raw_tensor)
  columns_to_tensors[self] = self._do_transform(sparse_input)
def _transform_feature(self, inputs):
input_tensor = self._get_input_sparse_tensor(inputs.get(self.name))
return self._do_transform(input_tensor)
@property
def _parse_example_config(self):
return self.config
@property
def _num_buckets(self):
return self.length
def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
del weight_collections
del trainable
input_tensor = inputs.get(self)
return fc_core._CategoricalColumn.IdWeightPair( # pylint: disable=protected-access
self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
class _SparseColumnIntegerized(_SparseColumn):
  """See `sparse_column_with_integerized_feature`."""

  def _do_transform(self, input_tensor):
    """Maps the input values to ids by taking them modulo `bucket_size`.

    Args:
      input_tensor: object exposing `indices`, `values` and `dense_shape`.

    Returns:
      A `SparseTensor` with the same indices and dense shape as the input,
      whose values are `input_tensor.values` mod `self.bucket_size`.
    """
    bucketed_values = math_ops.mod(
        input_tensor.values, self.bucket_size, name="mod")
    return sparse_tensor_py.SparseTensor(
        input_tensor.indices, bucketed_values, input_tensor.dense_shape)
def sparse_column_with_integerized_feature(column_name, def sparse_column_with_integerized_feature(column_name,
@ -517,10 +550,7 @@ def sparse_column_with_integerized_feature(column_name,
class _SparseColumnHashed(_SparseColumn): class _SparseColumnHashed(_SparseColumn):
"""See `sparse_column_with_hash_bucket`.""" """See `sparse_column_with_hash_bucket`."""
def insert_transformed_feature(self, columns_to_tensors): def _do_transform(self, input_tensor):
"""Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
if self.dtype.is_integer: if self.dtype.is_integer:
sparse_values = string_ops.as_string(input_tensor.values) sparse_values = string_ops.as_string(input_tensor.values)
else: else:
@ -528,8 +558,8 @@ class _SparseColumnHashed(_SparseColumn):
sparse_id_values = string_ops.string_to_hash_bucket_fast( sparse_id_values = string_ops.string_to_hash_bucket_fast(
sparse_values, self.bucket_size, name="lookup") sparse_values, self.bucket_size, name="lookup")
columns_to_tensors[self] = sparse_tensor_py.SparseTensor( return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values,
input_tensor.indices, sparse_id_values, input_tensor.dense_shape) input_tensor.dense_shape)
def sparse_column_with_hash_bucket(column_name, def sparse_column_with_hash_bucket(column_name,
@ -572,16 +602,13 @@ def sparse_column_with_hash_bucket(column_name,
class _SparseColumnKeys(_SparseColumn):
  """See `sparse_column_with_keys`."""

  def _do_transform(self, input_tensor):
    """Looks `input_tensor` up in an index table built from the configured keys.

    Args:
      input_tensor: the tensor handed to the table's `lookup` method.

    Returns:
      The result of `table.lookup(input_tensor)`, where the table is created
      by `lookup.index_table_from_tensor` from `self.lookup_config`.
    """
    key_table = lookup.index_table_from_tensor(
        mapping=tuple(self.lookup_config.keys),
        default_value=self.lookup_config.default_value,
        dtype=self.dtype,
        name="lookup")
    ids = key_table.lookup(input_tensor)
    return ids
def sparse_column_with_keys( def sparse_column_with_keys(
@ -621,9 +648,7 @@ def sparse_column_with_keys(
class _SparseColumnVocabulary(_SparseColumn): class _SparseColumnVocabulary(_SparseColumn):
"""See `sparse_column_with_vocabulary_file`.""" """See `sparse_column_with_vocabulary_file`."""
def insert_transformed_feature(self, columns_to_tensors): def _do_transform(self, st):
"""Handles sparse column to id conversion."""
st = self._get_input_sparse_tensor(columns_to_tensors)
if self.dtype.is_integer: if self.dtype.is_integer:
sparse_string_values = string_ops.as_string(st.values) sparse_string_values = string_ops.as_string(st.values)
sparse_string_tensor = sparse_tensor_py.SparseTensor(st.indices, sparse_string_tensor = sparse_tensor_py.SparseTensor(st.indices,
@ -638,7 +663,7 @@ class _SparseColumnVocabulary(_SparseColumn):
vocab_size=self.lookup_config.vocab_size, vocab_size=self.lookup_config.vocab_size,
default_value=self.lookup_config.default_value, default_value=self.lookup_config.default_value,
name=self.name + "_lookup") name=self.name + "_lookup")
columns_to_tensors[self] = table.lookup(sparse_string_tensor) return table.lookup(sparse_string_tensor)
def sparse_column_with_vocabulary_file(column_name, def sparse_column_with_vocabulary_file(column_name,
@ -694,9 +719,12 @@ def sparse_column_with_vocabulary_file(column_name,
dtype=dtype) dtype=dtype)
class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple( class _WeightedSparseColumn(
"_WeightedSparseColumn", _FeatureColumn,
["sparse_id_column", "weight_column_name", "dtype"])): fc_core._CategoricalColumn, # pylint: disable=protected-access
collections.namedtuple("_WeightedSparseColumn",
["sparse_id_column", "weight_column_name",
"dtype"])):
"""See `weighted_sparse_column`.""" """See `weighted_sparse_column`."""
def __new__(cls, sparse_id_column, weight_column_name, dtype): def __new__(cls, sparse_id_column, weight_column_name, dtype):
@ -725,22 +753,6 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
"""Returns a string which will be used as a key when we do sorting.""" """Returns a string which will be used as a key when we do sorting."""
return "{}".format(self) return "{}".format(self)
def insert_transformed_feature(self, columns_to_tensors):
"""Inserts a tuple with the id and weight tensors."""
if self.sparse_id_column not in columns_to_tensors:
self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
weight_tensor = columns_to_tensors[self.weight_column_name]
if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
# The weight tensor can be a regular Tensor. In such case, sparsify it.
weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
if not self.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
columns_to_tensors[self] = tuple([
columns_to_tensors[self.sparse_id_column],
weight_tensor
])
def id_tensor(self, input_tensor): def id_tensor(self, input_tensor):
"""Returns the id tensor from the given transformed input_tensor.""" """Returns the id tensor from the given transformed input_tensor."""
return input_tensor[0] return input_tensor[0]
@ -768,6 +780,43 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer(), initializer=init_ops.zeros_initializer(),
combiner=self.sparse_id_column.combiner) combiner=self.sparse_id_column.combiner)
def _do_transform(self, id_tensor, weight_tensor):
if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
# The weight tensor can be a regular Tensor. In such case, sparsify it.
weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
if not self.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
return tuple([id_tensor, weight_tensor])
def insert_transformed_feature(self, columns_to_tensors):
  """Inserts a tuple with the id and weight tensors.

  If the wrapped `sparse_id_column` has no entry in `columns_to_tensors` yet,
  it is asked to insert its own transformed feature first. The pair returned
  by `_do_transform` is then stored under the column object itself.

  Args:
    columns_to_tensors: mutable mapping holding the raw weights under
      `self.weight_column_name`; it is updated in place.
  """
  id_column = self.sparse_id_column
  if id_column not in columns_to_tensors:
    id_column.insert_transformed_feature(columns_to_tensors)
  weights = columns_to_tensors[self.weight_column_name]
  ids = columns_to_tensors[id_column]
  columns_to_tensors[self] = self._do_transform(ids, weights)
def _transform_feature(self, inputs):
return self._do_transform(
inputs.get(self.sparse_id_column), inputs.get(self.weight_column_name))
@property
def _parse_example_config(self):
return self.config
@property
def _num_buckets(self):
return self.length
def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
del weight_collections
del trainable
input_tensor = inputs.get(self)
return fc_core._CategoricalColumn.IdWeightPair( # pylint: disable=protected-access
self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
def weighted_sparse_column(sparse_id_column, def weighted_sparse_column(sparse_id_column,
weight_column_name, weight_column_name,
@ -815,9 +864,10 @@ def weighted_sparse_column(sparse_id_column,
return _WeightedSparseColumn(sparse_id_column, weight_column_name, dtype) return _WeightedSparseColumn(sparse_id_column, weight_column_name, dtype)
class _OneHotColumn(_FeatureColumn, class _OneHotColumn(
collections.namedtuple("_OneHotColumn", _FeatureColumn,
["sparse_id_column"])): fc_core._DenseColumn, # pylint: disable=protected-access
collections.namedtuple("_OneHotColumn", ["sparse_id_column"])):
"""Represents a one-hot column for use in deep networks. """Represents a one-hot column for use in deep networks.
Args: Args:
@ -897,12 +947,31 @@ class _OneHotColumn(_FeatureColumn,
return math_ops.reduce_sum( return math_ops.reduce_sum(
one_hot_id_tensor, reduction_indices=[output_rank - 1]) one_hot_id_tensor, reduction_indices=[output_rank - 1])
@property
def _variable_shape(self):
return tensor_shape.TensorShape((self.length))
class _EmbeddingColumn(_FeatureColumn, collections.namedtuple( def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
"_EmbeddingColumn", del weight_collections
["sparse_id_column", "dimension", "combiner", "initializer", del trainable
"ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name", return inputs.get(self)
"shared_vocab_size", "max_norm", "trainable"])):
def _transform_feature(self, inputs):
return self._to_dnn_input_layer(inputs.get(self.sparse_id_column))
@property
def _parse_example_config(self):
return self.config
class _EmbeddingColumn(
_FeatureColumn,
fc_core._DenseColumn, # pylint: disable=protected-access
collections.namedtuple("_EmbeddingColumn", [
"sparse_id_column", "dimension", "combiner", "initializer",
"ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name",
"shared_vocab_size", "max_norm", "trainable"
])):
"""Represents an embedding column. """Represents an embedding column.
Args: Args:
@ -1027,6 +1096,139 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
raise ValueError("Column {} is not supported in linear models. " raise ValueError("Column {} is not supported in linear models. "
"Please use sparse_column.".format(self)) "Please use sparse_column.".format(self))
@property
def _variable_shape(self):
return tensor_shape.TensorShape((self.dimension))
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
return _embeddings_from_arguments(
self,
self._deep_embedding_lookup_arguments(inputs.get(self)),
weight_collections, trainable)
def _transform_feature(self, inputs):
return inputs.get(self.sparse_id_column)
@property
def _parse_example_config(self):
return self.config
def _is_variable(v):
"""Returns true if `v` is a variable."""
return isinstance(v, (variables.Variable,
resource_variable_ops.ResourceVariable))
def _embeddings_from_arguments(column,
                               args,
                               weight_collections,
                               trainable,
                               output_rank=2):
  """Returns embeddings for a column based on the computed arguments.

  Args:
    column: the feature column object; its `name` and `_checkpoint_path()`
      are used here.
    args: the _DeepEmbeddingLookupArguments for this column.
    weight_collections: collections to store weights in.
    trainable: whether these embeddings should be trainable; combined with
      `args.trainable` via `and`.
    output_rank: the desired rank of the returned `Tensor`. Inner dimensions
      will be combined to produce the desired rank.

  Returns:
    the embeddings.

  Raises:
    ValueError: if the shared-embedding collection holds more than one entry,
      or if an existing shared embedding's shape differs from
      `[args.vocab_size, args.dimension]`.
  """
  # pylint: disable=protected-access
  flat_ids = layers._inner_flatten(args.input_tensor, output_rank)
  flat_weights = None
  if args.weight_tensor is not None:
    flat_weights = layers._inner_flatten(args.weight_tensor, output_rank)
  # pylint: enable=protected-access

  # Every branch below creates its variable with the same trainability.
  is_trainable = trainable and args.trainable

  # This option is only enabled for scattered_embedding_column.
  if args.hash_key:
    hashed_weights = contrib_variables.model_variable(
        name="weights",
        shape=[args.vocab_size],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=is_trainable,
        collections=weight_collections)
    return embedding_ops.scattered_embedding_lookup_sparse(
        hashed_weights,
        flat_ids,
        args.dimension,
        hash_key=args.hash_key,
        combiner=args.combiner,
        name="lookup")

  shared_name = args.shared_embedding_name
  if shared_name is None:
    embeddings = contrib_variables.model_variable(
        name="weights",
        shape=[args.vocab_size, args.dimension],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=is_trainable,
        collections=weight_collections)
  else:
    collection_name = "SHARED_EMBEDDING_COLLECTION_" + shared_name.upper()
    graph = ops.get_default_graph()
    existing = graph.get_collection_ref(collection_name)
    shape = [args.vocab_size, args.dimension]
    if not existing:
      # First column to use this shared name: create the variable and
      # register it so later columns pick it up from the collection.
      embeddings = contrib_variables.model_variable(
          name=shared_name,
          shape=shape,
          dtype=dtypes.float32,
          initializer=args.initializer,
          trainable=is_trainable,
          collections=weight_collections)
      graph.add_to_collection(collection_name, embeddings)
    elif len(existing) > 1:
      raise ValueError(
          "Collection %s can only contain one (partitioned) variable." %
          collection_name)
    else:
      embeddings = existing[0]
      if embeddings.get_shape() != shape:
        raise ValueError(
            "The embedding variable with name {} already exists, but its "
            "shape does not match required embedding shape here. Please make "
            "sure to use different shared_embedding_name for different "
            "shared embeddings.".format(args.shared_embedding_name))

  # pylint: disable=protected-access
  if _is_variable(embeddings):
    variable_list = [embeddings]
  else:
    variable_list = embeddings._get_variable_list()
  _maybe_restore_from_checkpoint(column._checkpoint_path(), variable_list)
  # pylint: enable=protected-access
  return embedding_ops.safe_embedding_lookup_sparse(
      variable_list,
      flat_ids,
      sparse_weights=flat_weights,
      combiner=args.combiner,
      name=column.name + "weights",
      max_norm=args.max_norm)
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
if checkpoint_path is not None:
path, tensor_name = checkpoint_path
weights_to_restore = variable
if len(variable) == 1:
weights_to_restore = variable[0]
checkpoint_utils.init_from_checkpoint(path,
{tensor_name: weights_to_restore})
def one_hot_column(sparse_id_column): def one_hot_column(sparse_id_column):
"""Creates an `_OneHotColumn` for a one-hot or multi-hot repr in a DNN. """Creates an `_OneHotColumn` for a one-hot or multi-hot repr in a DNN.

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import functools import functools
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import embedding_ops from tensorflow.contrib.layers.python.layers import embedding_ops
@ -34,118 +33,12 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest from tensorflow.python.util import nest
def _is_variable(v):
"""Returns true if `v` is a variable."""
return isinstance(v, (variables.Variable,
resource_variable_ops.ResourceVariable))
def _embeddings_from_arguments(column,
                               args,
                               weight_collections,
                               trainable,
                               output_rank=2):
  """Returns embeddings for a column based on the computed arguments.

  Args:
    column: the feature column object; its `name` and `_checkpoint_path()`
      are used here.
    args: the _DeepEmbeddingLookupArguments for this column.
    weight_collections: collections to store weights in.
    trainable: whether these embeddings should be trainable; combined with
      `args.trainable` via `and`.
    output_rank: the desired rank of the returned `Tensor`. Inner dimensions
      will be combined to produce the desired rank.

  Returns:
    the embeddings.

  Raises:
    ValueError: if the shared-embedding collection holds more than one entry,
      or if an existing shared embedding's shape differs from
      `[args.vocab_size, args.dimension]`.
  """
  # pylint: disable=protected-access
  input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
  weight_tensor = None
  if args.weight_tensor is not None:
    weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
  # pylint: enable=protected-access

  # This option is only enabled for scattered_embedding_column.
  if args.hash_key:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=(trainable and args.trainable),
        collections=weight_collections)

    return embedding_ops.scattered_embedding_lookup_sparse(
        embeddings, input_tensor, args.dimension,
        hash_key=args.hash_key,
        combiner=args.combiner, name='lookup')

  if args.shared_embedding_name is not None:
    shared_embedding_collection_name = (
        'SHARED_EMBEDDING_COLLECTION_' + args.shared_embedding_name.upper())
    graph = ops.get_default_graph()
    shared_embedding_collection = (
        graph.get_collection_ref(shared_embedding_collection_name))
    shape = [args.vocab_size, args.dimension]
    if shared_embedding_collection:
      # A variable was already registered under this shared name: reuse it
      # after checking that it is unique and has the expected shape.
      if len(shared_embedding_collection) > 1:
        raise ValueError('Collection %s can only contain one '
                         '(partitioned) variable.'
                         % shared_embedding_collection_name)
      else:
        embeddings = shared_embedding_collection[0]
        if embeddings.get_shape() != shape:
          raise ValueError('The embedding variable with name {} already '
                           'exists, but its shape does not match required '
                           'embedding shape here. Please make sure to use '
                           'different shared_embedding_name for different '
                           'shared embeddings.'.format(
                               args.shared_embedding_name))
    else:
      # First use of this shared name: create the variable and register it
      # in the collection so later calls find it.
      embeddings = contrib_variables.model_variable(
          name=args.shared_embedding_name,
          shape=shape,
          dtype=dtypes.float32,
          initializer=args.initializer,
          trainable=(trainable and args.trainable),
          collections=weight_collections)
      graph.add_to_collection(shared_embedding_collection_name, embeddings)
  else:
    embeddings = contrib_variables.model_variable(
        name='weights',
        shape=[args.vocab_size, args.dimension],
        dtype=dtypes.float32,
        initializer=args.initializer,
        trainable=(trainable and args.trainable),
        collections=weight_collections)

  # Normalize to a list of variables before restoring and looking up.
  if _is_variable(embeddings):
    embeddings = [embeddings]
  else:
    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
  # pylint: disable=protected-access
  _maybe_restore_from_checkpoint(
      column._checkpoint_path(), embeddings)
  return embedding_ops.safe_embedding_lookup_sparse(
      embeddings,
      input_tensor,
      sparse_weights=weight_tensor,
      combiner=args.combiner,
      name=column.name + 'weights',
      max_norm=args.max_norm)
def _maybe_reshape_input_tensor(tensor, column_name, output_rank): def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
"""Reshape the input tensor by the following rule. """Reshape the input tensor by the following rule.
@ -232,12 +125,13 @@ def _input_from_feature_columns(columns_to_tensors,
# pylint: disable=protected-access # pylint: disable=protected-access
arguments = column._deep_embedding_lookup_arguments( arguments = column._deep_embedding_lookup_arguments(
transformed_tensor) transformed_tensor)
output_tensors.append(_embeddings_from_arguments( output_tensors.append(
column, fc._embeddings_from_arguments( # pylint: disable=protected-access
arguments, column,
weight_collections, arguments,
trainable, weight_collections,
output_rank=output_rank)) trainable,
output_rank=output_rank))
except NotImplementedError as ee: except NotImplementedError as ee:
try: try:
@ -393,7 +287,7 @@ def _create_embedding_lookup(column,
initializer=embedding_lookup_arguments.initializer, initializer=embedding_lookup_arguments.initializer,
trainable=trainable, trainable=trainable,
collections=weight_collections) collections=weight_collections)
if _is_variable(variable): if fc._is_variable(variable): # pylint: disable=protected-access
variable = [variable] variable = [variable]
else: else:
variable = variable._get_variable_list() # pylint: disable=protected-access variable = variable._get_variable_list() # pylint: disable=protected-access
@ -406,16 +300,6 @@ def _create_embedding_lookup(column,
return variable, predictions return variable, predictions
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
if checkpoint_path is not None:
path, tensor_name = checkpoint_path
weights_to_restore = variable
if len(variable) == 1:
weights_to_restore = variable[0]
checkpoint_utils.init_from_checkpoint(path,
{tensor_name: weights_to_restore})
def _create_joint_embedding_lookup(columns_to_tensors, def _create_joint_embedding_lookup(columns_to_tensors,
embedding_lookup_arguments, embedding_lookup_arguments,
num_outputs, num_outputs,
@ -451,7 +335,7 @@ def _create_joint_embedding_lookup(columns_to_tensors,
initializer=init_ops.zeros_initializer(), initializer=init_ops.zeros_initializer(),
trainable=trainable, trainable=trainable,
collections=weight_collections) collections=weight_collections)
if _is_variable(variable): if fc._is_variable(variable): # pylint: disable=protected-access
variable = [variable] variable = [variable]
else: else:
variable = variable._get_variable_list() # pylint: disable=protected-access variable = variable._get_variable_list() # pylint: disable=protected-access
@ -634,7 +518,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
predictions, shape=(-1, num_outputs))) predictions, shape=(-1, num_outputs)))
column_to_variable[column] = variable column_to_variable[column] = variable
_log_variable(variable) _log_variable(variable)
_maybe_restore_from_checkpoint(column._checkpoint_path(), variable) fc._maybe_restore_from_checkpoint(column._checkpoint_path(), variable) # pylint: disable=protected-access
# pylint: enable=protected-access # pylint: enable=protected-access
predictions_no_bias = math_ops.add_n(output_tensors) predictions_no_bias = math_ops.add_n(output_tensors)
bias = contrib_variables.model_variable( bias = contrib_variables.model_variable(
@ -827,10 +711,10 @@ def parse_feature_columns_from_sequence_examples(
def _log_variable(variable): def _log_variable(variable):
if isinstance(variable, list): if isinstance(variable, list):
for var in variable: for var in variable:
if _is_variable(variable): if fc._is_variable(variable): # pylint: disable=protected-access
logging.info('Created variable %s, with device=%s', var.name, logging.info('Created variable %s, with device=%s', var.name,
var.device) var.device)
elif _is_variable(variable): elif fc._is_variable(variable): # pylint: disable=protected-access
logging.info('Created variable %s, with device=%s', variable.name, logging.info('Created variable %s, with device=%s', variable.name,
variable.device) variable.device)

View File

@ -597,12 +597,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
"income": "income":
constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]), constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]),
} }
output = feature_column_ops.input_from_feature_columns(features, [ columns = [one_hot_column, embedding_column, real_valued_column]
one_hot_column, embedding_column, real_valued_column]) output = feature_column_ops.input_from_feature_columns(features, columns)
output_core = fc_core.make_input_layer(features, columns)
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10]) self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
def testRealValuedColumn(self): def testRealValuedColumn(self):
real_valued = feature_column.real_valued_column("price") real_valued = feature_column.real_valued_column("price")
@ -712,11 +715,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_column = feature_column.one_hot_column(weighted_ids_column) one_hot_column = feature_column.one_hot_column(weighted_ids_column)
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[one_hot_column]) [one_hot_column])
output_core = fc_core.make_input_layer(features, [one_hot_column])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]], self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
output.eval()) output.eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromSparseColumnWithKeysSucceedsForDNN(self): def testOneHotColumnFromSparseColumnWithKeysSucceedsForDNN(self):
ids_column = feature_column.sparse_column_with_keys( ids_column = feature_column.sparse_column_with_keys(
@ -729,12 +735,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"ids": ids_tensor} features = {"ids": ids_tensor}
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse]) [one_hot_sparse])
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]], self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
output.eval()) output.eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromMultivalentSparseColumnWithKeysSucceedsForDNN(self): def testOneHotColumnFromMultivalentSparseColumnWithKeysSucceedsForDNN(self):
ids_column = feature_column.sparse_column_with_keys( ids_column = feature_column.sparse_column_with_keys(
@ -747,12 +756,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"ids": ids_tensor} features = {"ids": ids_tensor}
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse]) [one_hot_sparse])
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval()) output.eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromSparseColumnWithIntegerizedFeaturePassesForDNN(self): def testOneHotColumnFromSparseColumnWithIntegerizedFeaturePassesForDNN(self):
ids_column = feature_column.sparse_column_with_integerized_feature( ids_column = feature_column.sparse_column_with_integerized_feature(
@ -767,10 +779,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
} }
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse]) [one_hot_sparse])
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]], self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval()) output.eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromSparseColumnWithHashBucketSucceedsForDNN(self): def testOneHotColumnFromSparseColumnWithHashBucketSucceedsForDNN(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("feat", 10) hashed_sparse = feature_column.sparse_column_with_hash_bucket("feat", 10)
@ -782,10 +797,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_sparse = feature_column.one_hot_column(hashed_sparse) one_hot_sparse = feature_column.one_hot_column(hashed_sparse)
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse]) [one_hot_sparse])
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape) self.assertAllEqual([3, 10], output.eval().shape)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval(), output_core.eval())
def testEmbeddingColumnSucceedsForDNN(self): def testEmbeddingColumnSucceedsForDNN(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10) hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
@ -797,9 +815,12 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10) embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse]) [embeded_sparse])
output_core = fc_core.make_input_layer(features, [embeded_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [4, 10]) self.assertAllEqual(output.eval().shape, [4, 10])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
def testScatteredEmbeddingColumnSucceedsForDNN(self): def testScatteredEmbeddingColumnSucceedsForDNN(self):
wire_tensor = sparse_tensor.SparseTensor( wire_tensor = sparse_tensor.SparseTensor(
@ -838,12 +859,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.constant_initializer(init_value)) initializer=init_ops.constant_initializer(init_value))
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse]) [embeded_sparse])
output_core = fc_core.make_input_layer(features, [embeded_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
output_eval = output.eval() output_eval = output.eval()
self.assertAllEqual(output_eval.shape, [2, 10]) self.assertAllEqual(output_eval.shape, [2, 10])
self.assertAllClose(output_eval, np.tile(init_value, [2, 10])) self.assertAllClose(output_eval, np.tile(init_value, [2, 10]))
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval(), output_core.eval())
def testEmbeddingColumnWithMultipleInitializersFails(self): def testEmbeddingColumnWithMultipleInitializersFails(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10) hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
@ -889,10 +913,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(weighted_ids, 10) embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
output = feature_column_ops.input_from_feature_columns(features, output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse]) [embeded_sparse])
output_core = fc_core.make_input_layer(features, [embeded_sparse])
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10]) self.assertAllEqual(output.eval().shape, [2, 10])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self): def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
"""Same as the previous test, but with integer weights.""" """Same as the previous test, but with integer weights."""
@ -1534,9 +1562,12 @@ class WeightedSumTest(test.TestCase):
features = {"wire": wire_tensor} features = {"wire": wire_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5) features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5]) self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(logits.eval(), logits_core.eval())
def testSparseIntColumn(self): def testSparseIntColumn(self):
"""Tests a sparse column with int values.""" """Tests a sparse column with int values."""
@ -1549,9 +1580,12 @@ class WeightedSumTest(test.TestCase):
features = {"wire": wire_tensor} features = {"wire": wire_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5) features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5]) self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(logits.eval(), logits_core.eval())
def testSparseColumnWithDenseInputTensor(self): def testSparseColumnWithDenseInputTensor(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10) hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
@ -1560,9 +1594,12 @@ class WeightedSumTest(test.TestCase):
features = {"wire": wire_tensor} features = {"wire": wire_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5) features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5]) self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(logits.eval(), logits_core.eval())
def testWeightedSparseColumn(self): def testWeightedSparseColumn(self):
ids = feature_column.sparse_column_with_keys("ids", ids = feature_column.sparse_column_with_keys("ids",
@ -1579,10 +1616,13 @@ class WeightedSumTest(test.TestCase):
features = {"ids": ids_tensor, "weights": weights_tensor} features = {"ids": ids_tensor, "weights": weights_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_ids], num_outputs=5) features, [weighted_ids], num_outputs=5)
logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5]) self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(logits.eval(), logits_core.eval())
def testWeightedSparseColumnWithDenseInputTensor(self): def testWeightedSparseColumnWithDenseInputTensor(self):
ids = feature_column.sparse_column_with_keys( ids = feature_column.sparse_column_with_keys(
@ -1594,11 +1634,14 @@ class WeightedSumTest(test.TestCase):
features = {"ids": ids_tensor, "weights": weights_tensor} features = {"ids": ids_tensor, "weights": weights_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns( logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_ids], num_outputs=5) features, [weighted_ids], num_outputs=5)
logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
with self.test_session(): with self.test_session():
variables_lib.global_variables_initializer().run() variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5]) self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(logits.eval(), logits_core.eval())
def testCrossedColumn(self): def testCrossedColumn(self):
a = feature_column.sparse_column_with_hash_bucket( a = feature_column.sparse_column_with_hash_bucket(
@ -1649,6 +1692,8 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = ( output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns( feature_column_ops.weighted_sum_from_feature_columns(
features, [movies], num_outputs=1)) features, [movies], num_outputs=1))
logits_core = fc_core.make_linear_model(features, [movies])
with self.test_session() as sess: with self.test_session() as sess:
variables_lib.initialize_all_variables().run() variables_lib.initialize_all_variables().run()
lookup_ops.tables_initializer().run() lookup_ops.tables_initializer().run()
@ -1659,6 +1704,8 @@ class WeightedSumTest(test.TestCase):
# score for first example = 0.3 (matrix) + 0.1 (head-on) = 0.4 # score for first example = 0.3 (matrix) + 0.1 (head-on) = 0.4
# score for second example = 0.5 (winter sleep) # score for second example = 0.5 (winter sleep)
self.assertAllClose(output.eval(), [[0.4], [0.5]]) self.assertAllClose(output.eval(), [[0.4], [0.5]])
# Cross compatibility: Core builder output should equal to contrib.
self.assertAllEqual(output.eval().shape, logits_core.eval().shape)
def testRealValuedColumnWithMultiDimensions(self): def testRealValuedColumnWithMultiDimensions(self):
real_valued = feature_column.real_valued_column("price", 2) real_valued = feature_column.real_valued_column("price", 2)

View File

@ -36,7 +36,8 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
Xavier Glorot and Yoshua Bengio (2010): Xavier Glorot and Yoshua Bengio (2010):
[Understanding the difficulty of training deep feedforward neural [Understanding the difficulty of training deep feedforward neural
networks. International conference on artificial intelligence and networks. International conference on artificial intelligence and
statistics.](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.207.2059&rep=rep1&type=pdf) statistics.](
http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
This initializer is designed to keep the scale of the gradients roughly the This initializer is designed to keep the scale of the gradients roughly the
same in all layers. In uniform distribution this ends up being the range: same in all layers. In uniform distribution this ends up being the range:

View File

@ -102,9 +102,10 @@ def _linear_learning_rate(num_linear_feature_columns):
def _add_hidden_layer_summary(value, tag): def _add_hidden_layer_summary(value, tag):
summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value)) summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
summary.histogram("%s/activation" % tag, value) summary.histogram("%s/activation" % tag, value)
def _add_layer_summary(value, tag): def _add_layer_summary(value, tag):
summary.scalar("%s/fraction_of_zero_values" % tag, summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
nn.zero_fraction(value))
summary.histogram("%s/activation" % tag, value) summary.histogram("%s/activation" % tag, value)

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib import layers from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.framework import deprecated
from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import estimator

View File

@ -32,6 +32,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.summary import summary from tensorflow.python.summary import summary
from tensorflow.python.ops.control_flow_ops import with_dependencies from tensorflow.python.ops.control_flow_ops import with_dependencies
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook from tensorflow.python.training import session_run_hook
from tensorflow.python.training.session_run_hook import SessionRunArgs from tensorflow.python.training.session_run_hook import SessionRunArgs

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.contrib import layers from tensorflow.contrib import layers
from tensorflow.contrib import rnn as rnn_cell from tensorflow.contrib import rnn as rnn_cell
from tensorflow.contrib.framework.python.framework import deprecated
from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import constants

View File

@ -455,6 +455,7 @@ class LegacyConstructorTest(test.TestCase):
return {'inputs': inputs}, labels return {'inputs': inputs}, labels
return input_fn return input_fn
# TODO(jtbates): move all tests below to a benchmark test. # TODO(jtbates): move all tests below to a benchmark test.
class StateSavingRNNEstimatorLearningTest(test.TestCase): class StateSavingRNNEstimatorLearningTest(test.TestCase):
"""Learning tests for state saving RNN Estimators.""" """Learning tests for state saving RNN Estimators."""

View File

@ -22,6 +22,7 @@ import os
import tempfile import tempfile
import time import time
from tensorflow.contrib.learn.python.learn import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import run_config from tensorflow.contrib.learn.python.learn import run_config
@ -38,6 +39,7 @@ from tensorflow.python.training import saver
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import tf_inspect
class SheepCounter(object): class SheepCounter(object):
@ -119,6 +121,15 @@ class TestBaseEstimator(object):
compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp')) compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))
def _check_method_supports_args(method, kwargs):
"""Checks that the given method supports the given args."""
supported_args = tuple(tf_inspect.getargspec(method).args)
for kwarg in kwargs:
if kwarg not in supported_args:
raise ValueError(
'Argument `{}` is not supported in method {}.'.format(kwarg, method))
class TestEstimator( class TestEstimator(
TestBaseEstimator, evaluable.Evaluable, trainable.Trainable): TestBaseEstimator, evaluable.Evaluable, trainable.Trainable):
@ -126,9 +137,12 @@ class TestEstimator(
super(TestEstimator, self).__init__(config, max_evals, eval_dict) super(TestEstimator, self).__init__(config, max_evals, eval_dict)
tf_logging.info('Create Estimator') tf_logging.info('Create Estimator')
def evaluate(self, **kwargs):
_check_method_supports_args(evaluable.Evaluable.evaluate, kwargs)
return super(TestEstimator, self).evaluate(**kwargs)
def fit(self, **kwargs): def fit(self, **kwargs):
if 'hooks' in kwargs: _check_method_supports_args(trainable.Trainable.fit, kwargs)
raise ValueError('`hooks` is defined in core Estimator')
if 'monitors' in kwargs: if 'monitors' in kwargs:
self.monitors = kwargs['monitors'] self.monitors = kwargs['monitors']
return super(TestEstimator, self).train(**kwargs) return super(TestEstimator, self).train(**kwargs)
@ -136,6 +150,13 @@ class TestEstimator(
def train(self, **kwargs): def train(self, **kwargs):
raise ValueError('`train` is not defined in Estimator.') raise ValueError('`train` is not defined in Estimator.')
def export_savedmodel(
self, export_dir_base, serving_input_fn, **kwargs):
_check_method_supports_args(
estimator_lib.Estimator.export_savedmodel, kwargs)
return super(TestEstimator, self).export_savedmodel(
export_dir_base, serving_input_fn, **kwargs)
class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator): class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
@ -144,17 +165,22 @@ class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
tf_logging.info('Create Core Estimator') tf_logging.info('Create Core Estimator')
def evaluate(self, **kwargs): def evaluate(self, **kwargs):
if 'eval_metrics' in kwargs: _check_method_supports_args(core_estimator.Estimator.evaluate, kwargs)
raise ValueError('`eval_metrics` is not defined in core Estimator')
return super(TestCoreEstimator, self).evaluate(**kwargs) return super(TestCoreEstimator, self).evaluate(**kwargs)
def train(self, **kwargs): def train(self, **kwargs):
if 'monitors' in kwargs: _check_method_supports_args(core_estimator.Estimator.train, kwargs)
raise ValueError('`monitors` is not defined in core Estimator')
if 'hooks' in kwargs: if 'hooks' in kwargs:
self.monitors = kwargs['hooks'] self.monitors = kwargs['hooks']
return super(TestCoreEstimator, self).train(**kwargs) return super(TestCoreEstimator, self).train(**kwargs)
def export_savedmodel(
self, export_dir_base, serving_input_receiver_fn, **kwargs):
_check_method_supports_args(
core_estimator.Estimator.export_savedmodel, kwargs)
return super(TestCoreEstimator, self).export_savedmodel(
export_dir_base, serving_input_receiver_fn, **kwargs)
class _NoopHook(session_run_hook.SessionRunHook): class _NoopHook(session_run_hook.SessionRunHook):
pass pass
@ -184,6 +210,23 @@ class ExperimentTest(test.TestCase):
eval_input_fn='eval_input', eval_input_fn='eval_input',
eval_metrics='eval_metrics') eval_metrics='eval_metrics')
def test_default_output_alternative_key_core_estimator(self):
est = TestCoreEstimator()
export_strategy = saved_model_export_utils.make_export_strategy(
est,
default_output_alternative_key='export_key',
exports_to_keep=None)
ex = experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
train_steps=100,
eval_steps=100,
export_strategies=export_strategy)
with self.assertRaisesRegexp(
ValueError, 'default_output_alternative_key is not supported'):
ex.train_and_evaluate()
def test_train(self): def test_train(self):
for est in self._estimators_for_tests(): for est in self._estimators_for_tests():
eval_metrics = 'eval_metrics' if not isinstance( eval_metrics = 'eval_metrics' if not isinstance(
@ -508,7 +551,9 @@ class ExperimentTest(test.TestCase):
eval_metrics = 'eval_metrics' if not isinstance( eval_metrics = 'eval_metrics' if not isinstance(
est, core_estimator.Estimator) else None est, core_estimator.Estimator) else None
export_strategy_1 = saved_model_export_utils.make_export_strategy( export_strategy_1 = saved_model_export_utils.make_export_strategy(
est, 'export_input_1', exports_to_keep=None) est,
None if isinstance(est, core_estimator.Estimator) else 'export_1',
exports_to_keep=None)
ex = experiment.Experiment( ex = experiment.Experiment(
est, est,
@ -531,9 +576,13 @@ class ExperimentTest(test.TestCase):
# After reset with list, the count should increase with the number of # After reset with list, the count should increase with the number of
# items. # items.
export_strategy_2 = saved_model_export_utils.make_export_strategy( export_strategy_2 = saved_model_export_utils.make_export_strategy(
est, 'export_input_2', exports_to_keep=None) est,
None if isinstance(est, core_estimator.Estimator) else 'export_2',
exports_to_keep=None)
export_strategy_3 = saved_model_export_utils.make_export_strategy( export_strategy_3 = saved_model_export_utils.make_export_strategy(
est, 'export_input_3', exports_to_keep=None) est,
None if isinstance(est, core_estimator.Estimator) else 'export_3',
exports_to_keep=None)
old_es = ex.reset_export_strategies( old_es = ex.reset_export_strategies(
[export_strategy_2, export_strategy_3]) [export_strategy_2, export_strategy_3])
@ -547,7 +596,9 @@ class ExperimentTest(test.TestCase):
est, core_estimator.Estimator) else None est, core_estimator.Estimator) else None
noop_hook = _NoopHook() noop_hook = _NoopHook()
export_strategy = saved_model_export_utils.make_export_strategy( export_strategy = saved_model_export_utils.make_export_strategy(
est, 'export_input', exports_to_keep=None) est,
None if isinstance(est, core_estimator.Estimator) else 'export_input',
exports_to_keep=None)
ex = experiment.Experiment( ex = experiment.Experiment(
est, est,
train_input_fn='train_input', train_input_fn='train_input',
@ -625,7 +676,9 @@ class ExperimentTest(test.TestCase):
est, core_estimator.Estimator) else None est, core_estimator.Estimator) else None
noop_hook = _NoopHook() noop_hook = _NoopHook()
export_strategy = saved_model_export_utils.make_export_strategy( export_strategy = saved_model_export_utils.make_export_strategy(
est, 'export_input', exports_to_keep=None) est,
None if isinstance(est, core_estimator.Estimator) else 'export_input',
exports_to_keep=None)
ex = experiment.Experiment( ex = experiment.Experiment(
est, est,
train_input_fn='train_input', train_input_fn='train_input',
@ -646,7 +699,9 @@ class ExperimentTest(test.TestCase):
eval_metrics = 'eval_metrics' if not isinstance( eval_metrics = 'eval_metrics' if not isinstance(
est, core_estimator.Estimator) else None est, core_estimator.Estimator) else None
export_strategy = saved_model_export_utils.make_export_strategy( export_strategy = saved_model_export_utils.make_export_strategy(
est, 'export_input', exports_to_keep=None) est,
None if isinstance(est, core_estimator.Estimator) else 'export_input',
exports_to_keep=None)
ex = experiment.Experiment( ex = experiment.Experiment(
est, est,
train_input_fn='train_input', train_input_fn='train_input',
@ -796,7 +851,9 @@ class ExperimentTest(test.TestCase):
def test_test(self): def test_test(self):
for est in self._estimators_for_tests(): for est in self._estimators_for_tests():
exp_strategy = saved_model_export_utils.make_export_strategy( exp_strategy = saved_model_export_utils.make_export_strategy(
est, 'export_input', exports_to_keep=None) est,
None if isinstance(est, core_estimator.Estimator) else 'export_input',
exports_to_keep=None)
ex = experiment.Experiment( ex = experiment.Experiment(
est, est,
train_input_fn='train_input', train_input_fn='train_input',

View File

@ -42,6 +42,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import gc from tensorflow.contrib.learn.python.learn.utils import gc
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
@ -352,7 +353,8 @@ def make_export_strategy(serving_input_fn,
`InputFnOps`. `InputFnOps`.
default_output_alternative_key: the name of the head to serve when an default_output_alternative_key: the name of the head to serve when an
incoming serving request does not explicitly request a specific head. incoming serving request does not explicitly request a specific head.
Not needed for single-headed models. Must be `None` if the estimator inherits from ${tf.estimator.Estimator}
or for single-headed models.
assets_extra: A dict specifying how to populate the assets.extra directory assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel. Each key should give the destination within the exported SavedModel. Each key should give the destination
path (including the filename) relative to the assets.extra directory. path (including the filename) relative to the assets.extra directory.
@ -384,14 +386,30 @@ def make_export_strategy(serving_input_fn,
Returns: Returns:
The string path to the exported directory. The string path to the exported directory.
Raises:
ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
and `default_output_alternative_key` was specified.
""" """
export_result = estimator.export_savedmodel( if isinstance(estimator, core_estimator.Estimator):
export_dir_base, if default_output_alternative_key is not None:
serving_input_fn, raise ValueError(
default_output_alternative_key=default_output_alternative_key, 'default_output_alternative_key is not supported in core '
assets_extra=assets_extra, 'Estimator. Given: {}'.format(default_output_alternative_key))
as_text=as_text, export_result = estimator.export_savedmodel(
checkpoint_path=checkpoint_path) export_dir_base,
serving_input_fn,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path)
else:
export_result = estimator.export_savedmodel(
export_dir_base,
serving_input_fn,
default_output_alternative_key=default_output_alternative_key,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path)
garbage_collect_exports(export_dir_base, exports_to_keep) garbage_collect_exports(export_dir_base, exports_to_keep)
return export_result return export_result

View File

@ -1,9 +1,9 @@
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"]) exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
py_library( py_library(

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""##Signal ops.
"""
@@frames @@frames
""" """

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Signal ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division

View File

@ -33,34 +33,34 @@ class FramesTest(test.TestCase):
with self.test_session(): with self.test_session():
tensor = constant_op.constant(np.arange(9152), dtypes.int32) tensor = constant_op.constant(np.arange(9152), dtypes.int32)
tensor = array_ops.expand_dims(tensor, 0) tensor = array_ops.expand_dims(tensor, 0)
result = shape_ops.frames(tensor, 512, 180) result = shape_ops.frames(tensor, 512, 180)
result = result.eval() result = result.eval()
expected = np.tile(np.arange(512), (49, 1)) expected = np.tile(np.arange(512), (49, 1))
expected += np.tile(np.arange(49) * 180, (512, 1)).T expected += np.tile(np.arange(49) * 180, (512, 1)).T
expected = np.expand_dims(expected, axis=0) expected = np.expand_dims(expected, axis=0)
expected = np.array(expected, dtype=np.int32) expected = np.array(expected, dtype=np.int32)
self.assertAllEqual(expected, result) self.assertAllEqual(expected, result)
def test_mapping_of_indices_with_padding(self): def test_mapping_of_indices_with_padding(self):
with self.test_session(): with self.test_session():
tensor = constant_op.constant(np.arange(10000), dtypes.int32) tensor = constant_op.constant(np.arange(10000), dtypes.int32)
tensor = array_ops.expand_dims(tensor, 0) tensor = array_ops.expand_dims(tensor, 0)
result = shape_ops.frames(tensor, 512, 192) result = shape_ops.frames(tensor, 512, 192)
result = result.eval() result = result.eval()
expected = np.tile(np.arange(512), (51, 1)) expected = np.tile(np.arange(512), (51, 1))
expected += np.tile(np.arange(51) * 192, (512, 1)).T expected += np.tile(np.arange(51) * 192, (512, 1)).T
expected[expected >= 10000] = 0 expected[expected >= 10000] = 0
expected = np.expand_dims(expected, axis=0) expected = np.expand_dims(expected, axis=0)
expected = np.array(expected, dtype=np.int32) expected = np.array(expected, dtype=np.int32)
self.assertAllEqual(expected, result) self.assertAllEqual(expected, result)

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Signal ops."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""General shape ops for frames."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -23,59 +24,64 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
def frames(signal, frame_length, frame_step, name=None):
  """Frame a signal into overlapping frames.

  May be used in front of spectral functions.

  For example:

  ```python
  pcm = tf.placeholder(tf.float32, [None, 9152])
  frames = tf.contrib.signal.frames(pcm, 512, 180)
  magspec = tf.abs(tf.spectral.rfft(frames, [512]))
  image = tf.expand_dims(magspec, 3)
  ```

  Args:
    signal: A `Tensor` of shape `[batch_size, signal_length]`.
    frame_length: An `int32` or `int64` `Tensor`. The length of each frame.
    frame_step: An `int32` or `int64` `Tensor`. The step between frames.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`.

  Raises:
    ValueError: if signal does not have rank 2.
  """
  with ops.name_scope(name, "frames", [signal, frame_length, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")

    signal_rank = signal.shape.ndims

    if signal_rank != 2:
      # Format the rank with %s: it is an int (or None when the rank is
      # unknown), so concatenating it to the message with `+` would raise
      # TypeError instead of the documented ValueError.
      raise ValueError(
          "expected signal to have rank 2 but was %s" % signal_rank)

    signal_length = array_ops.shape(signal)[1]

    # Number of frames needed to cover the whole signal; the last frame may
    # extend past the end of the signal and is zero-padded below.
    num_frames = math_ops.ceil((signal_length - frame_length) / frame_step)
    num_frames = 1 + math_ops.cast(num_frames, dtypes.int32)

    pad_length = (num_frames - 1) * frame_step + frame_length
    pad_signal = array_ops.pad(signal, [[0, 0], [0,
                                                 pad_length - signal_length]])

    # indices[i, j] = i * frame_step + j: the position in the padded signal
    # of sample j of frame i.
    indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0)
    indices_frames = array_ops.tile(indices_frame, [num_frames, 1])

    indices_step = array_ops.expand_dims(
        math_ops.range(num_frames) * frame_step, 1)
    indices_steps = array_ops.tile(indices_step, [1, frame_length])

    indices = indices_frames + indices_steps

    # TODO(androbin): remove `transpose` when `gather` gets `axis` support
    pad_signal = array_ops.transpose(pad_signal)
    signal_frames = array_ops.gather(pad_signal, indices)
    signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1])

    return signal_frames

View File

@ -127,6 +127,6 @@ class FakeSummaryWriter(object):
def reopen(self):
  """Does nothing; the fake summary writer has no resource to reopen."""
def close(self):
  """Does nothing; the fake summary writer has no resource to close."""

View File

@ -97,6 +97,29 @@ py_test(
], ],
) )
# Profiler that converts TensorFlow run metadata into pprof-format profiles.
# Depends on the Python bindings generated from pprof's profile proto.
py_library(
name = "pprof_profiler",
srcs = ["pprof_profiler.py"],
srcs_version = "PY2AND3",
deps = ["@pprof_profile_proto//:pprof_proto_py"],
)
# Unit tests for :pprof_profiler. Tagged "no_pip", so it is not part of the
# pip-package test run (see the TODO on the tag).
py_test(
name = "pprof_profiler_test",
srcs = ["pprof_profiler_test.py"],
main = "pprof_profiler_test.py",
srcs_version = "PY2AND3",
tags = ["no_pip"], # TODO(annarev): get it working with pip.
deps = [
":pprof_profiler",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"@pprof_profile_proto//:pprof_proto_py",
],
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo. # Google-internal targets. These must be at the end for syncrepo.

View File

@ -0,0 +1,445 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Profiler for TensorFlow models that outputs data in pprof format.
See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
profile format.
The following needs to be set for profiler to work:
* trace_level needs to be set to FULL_TRACE
* run_metadata object should be passed in to session.run call
Sample usage:
options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
with tf.Session() as sess:
...
sess.run(computation, run_metadata=run_metadata, options=options)
pprof_profiler.profile(sess.graph, run_metadata, output_dir)
The code above would output a pprof profile to separate output_dir/.*.pb.gz
file for each device. These files can be passed to pprof for formatting.
For example:
pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
from collections import namedtuple
import gzip
import os
import string
import sys
import time
from pprof import profile_pb2
# Pick the translation-table factory that exists on the running interpreter:
# Python 2 exposes it as string.maketrans, Python 3 as str.maketrans.
maketrans = string.maketrans if sys.version_info < (3,) else str.maketrans

# One profiled node: its execution stats, its op type and the traceback
# recorded for the op with the same name.
ProfileDatum = namedtuple(
    'ProfileDatum', ('node_exec_stats', 'op_type', 'traceback'))
class StringTable(object):
  """Interns strings destined for the `string_table` of a pprof proto."""

  def __init__(self):
    # pprof requires the empty string to occupy the first slot.
    self._string_table = ['']
    self._string_to_index = {'': 0}

  def index_of(self, value_str):
    """Returns the index of `value_str`, appending it if it is not present.

    Args:
      value_str: (string) Value to look up or add. `None` is treated as ''.

    Returns:
      Index of `value_str` in the string table.
    """
    key = '' if value_str is None else value_str
    known_index = self._string_to_index.get(key)
    if known_index is not None:
      return known_index

    # Not seen before: the new entry goes at the end of the table.
    new_index = len(self._string_table)
    self._string_table.append(key)
    self._string_to_index[key] = new_index
    return new_index

  def next_index(self):
    """Returns the index the next newly added string would receive."""
    return len(self._string_table)

  def string_table(self):
    """Returns the list of strings to store in pprof's string_table."""
    return self._string_table
class Functions(object):
  """Registry of `Function` protos for a pprof profile."""

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object used to intern names and paths.
    """
    self._string_table = string_table
    # (file_path, function_name, start_line_number) -> `Function` proto.
    self._function_key_to_function = {}

  def index_of(self, file_path, function_name, function_start_line):
    """Returns the id of the function, registering it on first use.

    Args:
      file_path: (string) Path to file where the function is defined.
      function_name: (string) Function name.
      function_start_line: (integer) Start line number of function definition.

    Returns:
      Function index.
    """
    key = (file_path, function_name, function_start_line)
    known = self._function_key_to_function.get(key)
    if known is not None:
      return known.id

    # Function ids are 1-based.
    new_id = len(self._function_key_to_function) + 1
    proto = profile_pb2.Function()
    proto.id = new_id
    # The name is interned before the file path; this order determines the
    # indexes assigned in the shared string table.
    proto.name = self._string_table.index_of(function_name)
    proto.filename = self._string_table.index_of(file_path)
    proto.start_line = function_start_line
    self._function_key_to_function[key] = proto
    return new_id

  def function_protos(self):
    """Returns the registered `profile_pb2.Function` protos."""
    return self._function_key_to_function.values()
class Locations(object):
  """Registry of `Location` protos for a pprof profile.

  A `Location` records a place from which a function is called.
  """

  def __init__(self, functions):
    """Constructor.

    Args:
      functions: A `Functions` object used to register called functions.
    """
    self._functions = functions
    # (file_path, called_function_name, line_number) -> `Location` proto.
    self._location_key_to_location = {}

  def index_of(
      self, file_path, line_number, called_function_name, called_file_path,
      called_function_start_line):
    """Returns the id of the location, registering it on first use.

    Args:
      file_path: (string) Path to file that makes the call.
      line_number: (integer) Call line number.
      called_function_name: (string) Function name of the function called at
        `file_path` and `line_number`.
      called_file_path: (string) Path to file where the called function is
        defined.
      called_function_start_line: (integer) Start line number of called
        function definition in `called_file_path` file.

    Returns:
      Index of location.
    """
    key = (file_path, called_function_name, line_number)
    known = self._location_key_to_location.get(key)
    if known is not None:
      return known.id

    # Location ids are 1-based.
    new_id = len(self._location_key_to_location) + 1
    proto = profile_pb2.Location()
    proto.id = new_id
    self._location_key_to_location[key] = proto

    # Each location carries one line entry pointing at the called function.
    line = proto.line.add()
    line.function_id = self._functions.index_of(
        called_file_path, called_function_name, called_function_start_line)
    line.line = line_number
    return new_id

  def location_protos(self):
    """Returns the registered `profile_pb2.Location` protos."""
    return self._location_key_to_location.values()
class Samples(object):
  """Keeps track of `Sample` protos for pprof profile.

  Samples store the following statistics in order:
  count, all_time, op_time
  """

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # TODO(annarev): figure out if location is unique for each node name.
    # If not, also key this dictionary based on location ids.
    self._node_name_to_sample = {}

  def add(self, datum, location_ids):
    """Adds a sample data point.

    All data points for the same node name are aggregated into one `Sample`.
    The call stack (`location_ids`) is attached when the sample is first
    created; later data points for the same node only update the values.

    Args:
      datum: `ProfileDatum` to add a sample for.
      location_ids: List of numeric location ids for this
        sample.
    """
    node_name = datum.node_exec_stats.node_name
    if node_name in self._node_name_to_sample:
      sample = self._node_name_to_sample[node_name]
    else:
      sample = profile_pb2.Sample()
      # Sample stores 3 values: count, all_time, op_time
      sample.value.extend([0, 0, 0])
      # Record the stack exactly once. Previously it was only appended on
      # repeat executions, so a node that ran once had no stack at all and a
      # node that ran k times carried k-1 concatenated copies of it.
      sample.location_id.extend(location_ids)

      label = sample.label.add()
      label.key = self._string_table.index_of('node_name')
      label.str = self._string_table.index_of(node_name)
      label = sample.label.add()
      label.key = self._string_table.index_of('op_type')
      label.str = self._string_table.index_of(datum.op_type)
      self._node_name_to_sample[node_name] = sample
    sample.value[0] += 1
    sample.value[1] += datum.node_exec_stats.all_end_rel_micros
    sample.value[2] += (
        datum.node_exec_stats.op_end_rel_micros -
        datum.node_exec_stats.op_start_rel_micros)

  def get_sample_protos(self):
    """Returns list of `Sample` protos for pprof profile."""
    return self._node_name_to_sample.values()
class PprofProfiler(object):
  """Builds pprof-format profiles from a graph and its run metadata."""

  def __init__(self, graph, run_metadata):
    """Constructor.

    Args:
      graph: A `Graph` instance.
      run_metadata: A `RunMetadata` proto; its `step_stats.dev_stats` entries
        are profiled one device at a time.
    """
    self._graph = graph
    self._run_metadata = run_metadata
    # The string, function and location tables are shared by every device
    # profile produced by this object.
    self._string_table = StringTable()
    self._functions = Functions(self._string_table)
    self._locations = Locations(self._functions)

  def profile(self):
    """Generates pprof profiles.

    Returns:
      Dictionary mapping from device name to proto in `profile_pb2.Profile`
      format. Devices that yield no samples are reported on stdout and left
      out.
    """
    all_device_stats = self._run_metadata.step_stats.dev_stats
    device_count = len(all_device_stats)
    make_data = self._get_profile_data_generator()

    profiles = {}
    for device_number, device_stats in enumerate(all_device_stats, 1):
      pprof_proto = self._get_pprof_proto(make_data(device_stats))
      if not pprof_proto.sample:
        print(
            'Not enough data to create profile for device %s. Did you pass '
            'RunMetadata to session.run call?' % device_stats.device)
        continue

      # Describe the device in a profile comment. The description is appended
      # only to this proto's string table, at the index the shared table
      # would hand out next.
      description = (
          'Device %d of %d: %s' %
          (device_number, device_count, device_stats.device))
      description_index = self._string_table.next_index()
      pprof_proto.string_table.append(description)
      pprof_proto.comment.append(description_index)
      profiles[device_stats.device] = pprof_proto
    return profiles

  def _get_pprof_proto(self, profile_datum_generator):
    """Returns profile data in pprof proto format.

    Args:
      profile_datum_generator: Generator outputting `ProfileDatum` objects.

    Returns:
      A proto in pprof format.
    """
    pprof_profile = profile_pb2.Profile()
    samples = Samples(self._string_table)

    for datum in profile_datum_generator:
      traceback = datum.traceback
      if not traceback:
        continue

      # Walk the stack from the last frame towards the first. `callee` is the
      # frame visited in the previous step; `caller` is the frame calling it.
      caller = traceback[-1]
      seen_apply_op = False
      location_ids = []
      for frame_index in range(len(traceback) - 2, -1, -1):
        callee = caller
        caller = traceback[frame_index]
        callee_file_path = callee[0]
        callee_function = callee[2]
        callee_start_line = callee[4]
        caller_file_path = caller[0]
        caller_line_number = caller[1]

        # Frames up to and including the apply_op call are identical for all
        # ops, so they are left out of the profile.
        if not seen_apply_op:
          seen_apply_op = callee_function == 'apply_op'
          continue

        location_ids.append(self._locations.index_of(
            caller_file_path, caller_line_number,
            callee_function, callee_file_path, callee_start_line))
      samples.add(datum, location_ids)

    # The order here must match the order of values kept by `Samples`.
    # NOTE(review): the time values are summed from *_rel_micros fields while
    # the unit is labelled 'nanoseconds' -- confirm the intended unit.
    for type_name, unit_name in (('count', 'count'),
                                 ('all_time', 'nanoseconds'),
                                 ('op_time', 'nanoseconds')):
      sample_type = pprof_profile.sample_type.add()
      sample_type.type = self._string_table.index_of(type_name)
      sample_type.unit = self._string_table.index_of(unit_name)

    pprof_profile.string_table.extend(self._string_table.string_table())
    pprof_profile.sample.extend(samples.get_sample_protos())
    pprof_profile.function.extend(self._functions.function_protos())
    pprof_profile.location.extend(self._locations.location_protos())
    return pprof_profile

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Returns:
      A function that, given the step stats of one device, generates a
      `ProfileDatum` for each of its nodes except '_SOURCE' and '_SINK'.
    """
    # Nodes without a matching graph op fall back to an empty traceback and
    # an empty op type.
    node_to_traceback = defaultdict(list)
    node_to_op_type = defaultdict(str)
    for op in self._graph.get_operations():
      node_to_traceback[op.name] = op.traceback_with_start_lines
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        node_name = node_stats.node_name
        if node_name in ('_SOURCE', '_SINK'):
          continue
        yield ProfileDatum(
            node_stats,
            node_to_op_type[node_name],
            node_to_traceback[node_name])

    return profile_data_generator
def get_profiles(graph, run_metadata):
  """Builds one pprof profile proto per device.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.

  Returns:
    A dictionary mapping from device name to pprof proto for that device.
  """
  profiler = PprofProfiler(graph, run_metadata)
  return profiler.profile()
def profile(graph, run_metadata, output_dir=None):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.
    output_dir: (string) Directory to output pprof profile to.
      Profile files for each device will be stored in compressed
      serialized proto format. If output_dir is None, profile protos
      will be printed to stdout instead.

  Returns:
    List of output files created by this profile call.
    (Note: this list will be empty if output_dir is None)
  """
  profiles = get_profiles(graph, run_metadata)
  time_suffix = None
  if output_dir:
    if not os.path.isdir(output_dir):
      os.makedirs(output_dir)
    # Computed once so every device file of this call shares the same suffix.
    time_suffix = time.strftime('%Y%m%d%H%M%S')

  profile_files = []
  for device, pprof_proto in profiles.items():
    if time_suffix is None:
      print('No output directory specified, printing to stdout instead.')
      print(pprof_proto)
    else:
      # Device names look like paths; make them safe to use as a file name.
      device_name = str(device).strip('/').translate(
          maketrans('/:', '__'))
      # Format only the file name and join it to the directory afterwards:
      # applying `%` to a template that contains output_dir would fail or
      # produce a wrong path when the directory itself contains '%'.
      file_name = '%s_%s.pb.gz' % (device_name, time_suffix)
      profile_file = os.path.join(output_dir, file_name)
      profile_files.append(profile_file)
      with gzip.open(profile_file, 'w') as output_file:
        print('Writing profile to %s...' % profile_file)
        output_file.write(pprof_proto.SerializeToString())
  return profile_files

View File

@ -0,0 +1,164 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pprof_profiler."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
from pprof import profile_pb2
from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler
from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
# Tests for pprof_profiler.get_profiles (in-memory protos) and
# pprof_profiler.profile (gzip files written to a temp directory).
class PprofProfilerTest(test.TestCase):
# Empty graph and empty RunMetadata: no profiles and no files are produced.
def testDataEmpty(self):
output_dir = test.get_temp_dir()
run_metadata = config_pb2.RunMetadata()
graph = test.mock.MagicMock()
graph.get_operations.return_value = []
profiles = pprof_profiler.get_profiles(graph, run_metadata)
self.assertEquals(0, len(profiles))
profile_files = pprof_profiler.profile(
graph, run_metadata, output_dir)
self.assertEquals(0, len(profile_files))
# The graph has one op but RunMetadata holds no device stats: still no
# profiles and no files.
def testRunMetadataEmpty(self):
output_dir = test.get_temp_dir()
run_metadata = config_pb2.RunMetadata()
graph = test.mock.MagicMock()
op1 = test.mock.MagicMock()
op1.name = 'Add/123'
op1.traceback = [('a/b/file1', 10, 'some_var')]
op1.type = 'add'
graph.get_operations.return_value = [op1]
profiles = pprof_profiler.get_profiles(graph, run_metadata)
self.assertEquals(0, len(profiles))
profile_files = pprof_profiler.profile(
graph, run_metadata, output_dir)
self.assertEquals(0, len(profile_files))
# One device with one node: checks the exact text form of the resulting
# profile, both as returned in memory and as read back from the gzip file.
def testValidProfile(self):
output_dir = test.get_temp_dir()
run_metadata = config_pb2.RunMetadata()
node1 = step_stats_pb2.NodeExecStats(
node_name='Add/123',
op_start_rel_micros=3,
op_end_rel_micros=5,
all_end_rel_micros=4)
run_metadata = config_pb2.RunMetadata()
device1 = run_metadata.step_stats.dev_stats.add()
device1.device = 'deviceA'
device1.node_stats.extend([node1])
graph = test.mock.MagicMock()
op1 = test.mock.MagicMock()
op1.name = 'Add/123'
# NOTE(review): the profiler reads `op.traceback_with_start_lines`, not
# `op.traceback`, so this value does not appear to be used; the expected
# proto below contains no function or location entries.
op1.traceback = [
('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
op1.type = 'add'
graph.get_operations.return_value = [op1]
# Sample values are count=1, all_time=4 (all_end_rel_micros) and
# op_time=2 (op_end_rel_micros - op_start_rel_micros).
expected_proto = """sample_type {
type: 5
unit: 5
}
sample_type {
type: 6
unit: 7
}
sample_type {
type: 8
unit: 7
}
sample {
value: 1
value: 4
value: 2
label {
key: 1
str: 2
}
label {
key: 3
str: 4
}
}
string_table: ""
string_table: "node_name"
string_table: "Add/123"
string_table: "op_type"
string_table: "add"
string_table: "count"
string_table: "all_time"
string_table: "nanoseconds"
string_table: "op_time"
string_table: "Device 1 of 1: deviceA"
comment: 9
"""
# Test with protos
profiles = pprof_profiler.get_profiles(graph, run_metadata)
self.assertEquals(1, len(profiles))
self.assertTrue('deviceA' in profiles)
self.assertEquals(expected_proto, str(profiles['deviceA']))
# Test with files
profile_files = pprof_profiler.profile(
graph, run_metadata, output_dir)
self.assertEquals(1, len(profile_files))
with gzip.open(profile_files[0]) as profile_file:
profile_contents = profile_file.read()
profile = profile_pb2.Profile()
profile.ParseFromString(profile_contents)
self.assertEquals(expected_proto, str(profile))
# Runs a real while loop with full tracing and checks that the repeated
# executions of one node are aggregated into a single sample.
def testProfileWithWhileLoop(self):
options = config_pb2.RunOptions()
options.trace_level = config_pb2.RunOptions.FULL_TRACE
run_metadata = config_pb2.RunMetadata()
num_iters = 5
with self.test_session() as sess:
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, num_iters)
b = lambda i: math_ops.add(i, 1)
r = control_flow_ops.while_loop(c, b, [i])
sess.run(r, options=options, run_metadata=run_metadata)
profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
self.assertEquals(1, len(profiles))
profile = next(iter(profiles.values()))
add_samples = [] # Samples for the while/Add node
for sample in profile.sample:
# label[0] is the 'node_name' label; its `str` indexes the string table.
if profile.string_table[sample.label[0].str] == 'while/Add':
add_samples.append(sample)
# Values for same nodes are aggregated.
self.assertEquals(1, len(add_samples))
# Value of "count" should be equal to number of iterations.
self.assertEquals(num_iters, add_samples[0].value[0])
# Run the tests when this file is executed directly.
if __name__ == '__main__':
test.main()

View File

@ -272,7 +272,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
self_.qpn = qp_->qp_num; self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff; self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
union ibv_gid gid; union ibv_gid gid;
CHECK(!ibv_query_gid(adapter_->context_, (uint8_t) 1, 0, &gid)) << "Query gid"; CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
<< "Query gid";
self_.snp = gid.global.subnet_prefix; self_.snp = gid.global.subnet_prefix;
self_.iid = gid.global.interface_id; self_.iid = gid.global.interface_id;
} }
@ -479,7 +480,7 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.dest_qp_num = remoteAddr.qpn; attr.dest_qp_num = remoteAddr.qpn;
attr.rq_psn = remoteAddr.psn; attr.rq_psn = remoteAddr.psn;
attr.max_dest_rd_atomic = 1; attr.max_dest_rd_atomic = 1;
attr.min_rnr_timer = 12; attr.min_rnr_timer = 12;
attr.ah_attr.is_global = 1; attr.ah_attr.is_global = 1;
attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp; attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp;
attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid; attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid;

View File

@ -248,8 +248,8 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
tdata.size(), do_nothing); tdata.size(), do_nothing);
slices[1] = ::grpc::Slice(s1, ::grpc::Slice::STEAL_REF); slices[1] = ::grpc::Slice(s1, ::grpc::Slice::STEAL_REF);
gpr_slice s2 = gpr_slice_new(const_cast<TensorBuffer*>(buf), gpr_slice s2 =
0, unref_tensorbuffer); gpr_slice_new(const_cast<TensorBuffer*>(buf), 0, unref_tensorbuffer);
slices[2] = ::grpc::Slice(s2, ::grpc::Slice::STEAL_REF); slices[2] = ::grpc::Slice(s2, ::grpc::Slice::STEAL_REF);
num_slices += 2; num_slices += 2;
} }

View File

@ -135,6 +135,22 @@ cc_library(
], ],
) )
# Virtual placer library (virtual_placer.cc / virtual_placer.h); a dependency
# of :analytical_cost_estimator in this package.
cc_library(
name = "virtual_placer",
srcs = ["virtual_placer.cc"],
hdrs = ["virtual_placer.h"],
visibility = ["//visibility:public"],
deps = [
":op_performance_data_cc",
":utils",
"//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler/clusters:cluster",
],
)
cc_library( cc_library(
name = "virtual_scheduler", name = "virtual_scheduler",
srcs = ["virtual_scheduler.cc"], srcs = ["virtual_scheduler.cc"],
@ -194,3 +210,24 @@ cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
], ],
) )
# Cost estimator that predicts the cost of running a Grappler item from the
# theoretical performance of the hardware. It combines the op-level cost
# estimator with the virtual placer and virtual scheduler.
cc_library(
name = "analytical_cost_estimator",
srcs = ["analytical_cost_estimator.cc"],
hdrs = ["analytical_cost_estimator.h"],
visibility = ["//visibility:public"],
deps = [
":cost_estimator",
":graph_properties",
":op_level_cost_estimator",
":op_performance_data_cc",
":utils",
":virtual_placer",
":virtual_scheduler",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
],
)

View File

@ -0,0 +1,128 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include <limits>
#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace grappler {
AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
bool use_static_shapes)
: cluster_(cluster), use_static_shapes_(use_static_shapes) {}
// Stores a copy of `item` for use by PredictCosts(). Always returns OK.
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
  item_ = item;
  return Status::OK();
}
// Estimates the cost of every node of `optimized_graph` in the order chosen
// by the virtual scheduler. When `cost_graph` is non-null, each visited node
// gets (or reuses) an entry annotated with its device, costs and output
// properties. `costs` receives the scheduler's summary, or a maximal
// execution time when shape inference fails.
Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
                                             CostGraphDef* cost_graph,
                                             Costs* costs) const {
  // Infer shapes on a copy of the stored item whose graph is the optimized one.
  GrapplerItem item = item_;
  item.graph = optimized_graph;
  GraphProperties properties(item);
  Status status = use_static_shapes_ ? properties.InferStatically()
                                     : properties.InferDynamically(cluster_);
  if (!status.ok()) {
    costs->execution_time = Costs::Duration::max();
    return status;
  }

  // Index the entries already present in the cost graph by node name so they
  // are updated in place rather than duplicated.
  std::unordered_map<string, CostGraphDef::Node*> name_to_cost;
  if (cost_graph != nullptr) {
    for (auto& existing_node : *cost_graph->mutable_node()) {
      name_to_cost[existing_node.name()] = &existing_node;
    }
  }

  std::vector<string> inaccurate_nodes;
  VirtualScheduler scheduler(optimized_graph, item_.fetch);
  VirtualPlacer placer(cluster_);

  // Declared outside the loop because the loop condition consumes it.
  Costs node_costs;
  do {
    const NodeDef* node = scheduler.GetCurrNode();
    const string& node_name = node->name();

    // Describe the op (type, attributes, inputs, device) for the estimator.
    std::vector<OpInfo::TensorProperties> inputs =
        properties.GetInputProperties(node_name);
    OpInfo::DeviceProperties device = placer.get_device(*node);
    OpInfo op_info;
    op_info.set_op(node->op());
    *op_info.mutable_attr() = node->attr();
    for (auto& input : inputs) {
      op_info.add_inputs()->Swap(&input);
    }
    op_info.mutable_device()->Swap(&device);

    node_costs = node_estimator_.PredictCosts(op_info);
    if (node_costs.inaccurate) {
      inaccurate_nodes.push_back(node_name);
    }

    if (cost_graph != nullptr) {
      const auto found = name_to_cost.find(node_name);
      CostGraphDef::Node* cost_node = nullptr;
      if (found == name_to_cost.end()) {
        cost_node = cost_graph->add_node();
        cost_node->set_name(node_name);
      } else {
        cost_node = found->second;
      }

      cost_node->set_device(properties.GetDeviceName(node_name));
      cost_node->set_compute_cost(
          node_costs.execution_time.asMicroSeconds().count());
      cost_node->set_compute_time(
          node_costs.compute_time.asMicroSeconds().count());
      cost_node->set_memory_time(
          node_costs.memory_time.asMicroSeconds().count());

      std::vector<OpInfo::TensorProperties> outputs =
          properties.GetOutputProperties(node_name);
      for (const auto& output : outputs) {
        auto* output_info = cost_node->add_output_info();
        output_info->set_dtype(output.dtype());
        *output_info->mutable_shape() = output.shape();
      }
    }
  } while (scheduler.MarkCurrNodeExecuted(node_costs));

  *costs = scheduler.Summary();

  VLOG(1) << inaccurate_nodes.size() << " out of "
          << optimized_graph.node_size()
          << " nodes have inaccurate time estimation";
  for (const auto& inaccurate_node : inaccurate_nodes) {
    VLOG(2) << "Node with inaccurate time estimation: " << inaccurate_node;
  }
  return Status::OK();
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class CostGraphDef;
class GraphDef;
} // namespace tensorflow
namespace tensorflow {
namespace grappler {
class Cluster;
struct GrapplerItem;
// Estimate the cost of running a Grappler item based on the theoretical
// performance of the hardware that will run the model.
//
// This is a CostEstimator implementation: Initialize() and PredictCosts()
// override the corresponding base-class virtuals.
class AnalyticalCostEstimator : public CostEstimator {
public:
// Does not take ownership of cluster.
// NOTE(review): use_static_shapes is presumably stored in use_static_shapes_
// and consulted during prediction -- confirm against the .cc file.
explicit AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
~AnalyticalCostEstimator() override {}
// Initializes the estimator for the specified grappler item.
// This implementation always returns OK.
Status Initialize(const GrapplerItem& item) override;
// Predict the performance of each node of the optimized graph and annotate
// the CostGraphDef with the corresponding estimates. Also returns the
// expected latency for the whole graph.
// Both cost_graph and overall_latency are output parameters; the method is
// const, so it does not modify the estimator itself.
Status PredictCosts(const GraphDef& optimized_graph, CostGraphDef* cost_graph,
Costs* overall_latency) const override;
private:
Cluster* cluster_; // Not owned.
// Held by value (not by reference or pointer).
// NOTE(review): presumably a copy of the item given to Initialize() --
// confirm in the .cc file.
GrapplerItem item_;
// Per-op cost model; the name suggests it prices individual nodes -- confirm
// usage in the .cc file.
OpLevelCostEstimator node_estimator_;
// Value of the use_static_shapes constructor argument (assumed -- TODO
// confirm in the .cc file).
bool use_static_shapes_;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_

View File

@ -80,7 +80,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo(); const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo();
// Check if vector instructions are available, and refine performance // Check if vector instructions are available, and refine performance
// prediction based on this. // prediction based on this.
gflops = local_cpu.num_cores() * local_cpu.frequency(); // Frequencies are stored in MHz in the DeviceProperties.
gflops = local_cpu.num_cores() * local_cpu.frequency() * 1e-3;
if (bandwidth < 0) { if (bandwidth < 0) {
if (local_cpu.bandwidth() > 0) { if (local_cpu.bandwidth() > 0) {
bandwidth = local_cpu.bandwidth() / 1e6; bandwidth = local_cpu.bandwidth() / 1e6;
@ -105,7 +106,7 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
// Pascal. // Pascal.
cores_per_multiprocessor = 64; cores_per_multiprocessor = 64;
} }
gflops = local_gpu.num_cores() * local_gpu.frequency() * gflops = local_gpu.num_cores() * local_gpu.frequency() * 1e-3 *
cores_per_multiprocessor * kOpsPerMac; cores_per_multiprocessor * kOpsPerMac;
if (bandwidth < 0) { if (bandwidth < 0) {
CHECK(local_gpu.bandwidth() > 0); CHECK(local_gpu.bandwidth() > 0);

View File

@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() {
// Combine cpu family and model into the model string. // Combine cpu family and model into the model string.
device.set_model( device.set_model(
strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum())); strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
device.set_frequency(port::NominalCPUFrequency() * 1e-9); device.set_frequency(port::NominalCPUFrequency() * 1e-6);
device.set_num_cores(port::NumSchedulableCPUs()); device.set_num_cores(port::NumSchedulableCPUs());
device.set_l1_cache_size(Eigen::l1CacheSize()); device.set_l1_cache_size(Eigen::l1CacheSize());
device.set_l2_cache_size(Eigen::l2CacheSize()); device.set_l2_cache_size(Eigen::l2CacheSize());
@ -175,7 +175,7 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) {
if (error == cudaSuccess) { if (error == cudaSuccess) {
device.set_vendor("NVidia"); device.set_vendor("NVidia");
device.set_model(properties.name); device.set_model(properties.name);
device.set_frequency(properties.clockRate / 1000); device.set_frequency(properties.clockRate * 1e-3);
device.set_num_cores(properties.multiProcessorCount); device.set_num_cores(properties.multiProcessorCount);
device.set_num_registers(properties.regsPerMultiprocessor); device.set_num_registers(properties.regsPerMultiprocessor);
// For compute capability less than 5, l1 cache size is configurable to // For compute capability less than 5, l1 cache size is configurable to

View File

@ -0,0 +1,57 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace grappler {
// Populates the device table used by get_device().
//
// The local CPU is always registered under "CPU". The first local GPU
// (index 0) is registered under "GPU" only when at least one GPU is
// available. The fallback entry returned for unrecognized device types is
// tagged "UNKNOWN". The `cluster` argument is not consulted here.
VirtualPlacer::VirtualPlacer(Cluster* cluster) : has_gpu_(false) {
  devices_["CPU"] = GetLocalCPUInfo();

  const bool gpu_available = GetNumAvailableGPUs() > 0;
  if (gpu_available) {
    has_gpu_ = gpu_available;
    devices_["GPU"] = GetLocalGPUInfo(0);
  }

  unknown_device_.set_type("UNKNOWN");
}
// Returns the properties of the device that `node` is expected to run on.
//
// Resolution order:
//   1. If the node's device string parses and names a device type, that type
//      is used.
//   2. Otherwise -- no device string, an unparsable one, or a partial spec
//      such as "/job:worker/task:0" that parses but names no type -- the
//      default device is assumed: "GPU" when one was detected at
//      construction, else "CPU".
// A device type that is not registered in devices_ yields unknown_device_.
//
// The returned reference points at storage owned by this placer and is only
// valid while the placer is alive.
const OpInfo::DeviceProperties& VirtualPlacer::get_device(
    const NodeDef& node) const {
  DeviceNameUtils::ParsedName parsed;
  // ParseFullName() also succeeds on partial specifications, in which case
  // parsed.type is left empty; has_type tells the two cases apart. Without
  // this check a partially specified node would be reported as running on
  // the unknown device instead of the default one.
  const bool names_device_type =
      !node.device().empty() &&
      DeviceNameUtils::ParseFullName(node.device(), &parsed) &&
      parsed.has_type;

  const string device_type =
      names_device_type ? parsed.type : string(has_gpu_ ? "GPU" : "CPU");

  const auto it = devices_.find(device_type);
  return it == devices_.end() ? unknown_device_ : it->second;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,45 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
#include <unordered_map>
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class NodeDef;
namespace grappler {
class Cluster;
// The virtual placer emulates the behavior of the TF placer.
// Given a NodeDef it reports the OpInfo::DeviceProperties of the device the
// node is expected to be placed on.
class VirtualPlacer {
public:
// Builds the table of known devices.
// NOTE(review): single-argument constructor is not marked explicit, so a
// Cluster* converts implicitly to a VirtualPlacer -- confirm this is intended.
VirtualPlacer(Cluster* cluster);
// Returns the properties of the device for `node`. The reference refers to
// storage owned by this object (an entry of devices_ or unknown_device_), so
// it must not be used after the placer is destroyed.
const OpInfo::DeviceProperties& get_device(const NodeDef& node) const;
private:
// Known devices, keyed by device type string.
std::unordered_map<string, OpInfo::DeviceProperties> devices_;
// NOTE(review): presumably true when a GPU entry is present in devices_ --
// confirm in the .cc file.
bool has_gpu_;
// Returned by get_device() when the requested device type is not in devices_.
OpInfo::DeviceProperties unknown_device_;
};
} // namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_

View File

@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/core/kernels/crop_and_resize_op.h" #include "tensorflow/core/kernels/crop_and_resize_op.h"
#include <functional>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
@ -26,10 +29,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA #if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA
@ -37,41 +43,67 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice; typedef Eigen::GpuDevice GPUDevice;
using Callback = std::function<void()>;
static inline void ParseAndCheckBoxSizes(OpKernelContext* context, namespace {
const Tensor& boxes,
const Tensor& box_ind, static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
int* num_boxes) { const Tensor& box_index,
if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) { int* num_boxes) {
if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
*num_boxes = 0; *num_boxes = 0;
return; return Status::OK();
} }
// The shape of 'boxes' is [num_boxes, 4]. // The shape of 'boxes' is [num_boxes, 4].
OP_REQUIRES(context, boxes.dims() == 2, if (boxes.dims() != 2) {
errors::InvalidArgument("boxes must be 2-D", return errors::InvalidArgument("boxes must be 2-D",
boxes.shape().DebugString())); boxes.shape().DebugString());
}
*num_boxes = boxes.dim_size(0); *num_boxes = boxes.dim_size(0);
OP_REQUIRES(context, boxes.dim_size(1) == 4, if (boxes.dim_size(1) != 4) {
errors::InvalidArgument("boxes must have 4 columns")); return errors::InvalidArgument("boxes must have 4 columns");
}
// The shape of 'box_ind' is [num_boxes]. // The shape of 'box_index' is [num_boxes].
OP_REQUIRES(context, box_ind.dims() == 1, if (box_index.dims() != 1) {
errors::InvalidArgument("box_ind must be 1-D", return errors::InvalidArgument("box_index must be 1-D",
box_ind.shape().DebugString())); box_index.shape().DebugString());
OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes, }
errors::InvalidArgument("box_ind has incompatible shape")); if (box_index.dim_size(0) != *num_boxes) {
return errors::InvalidArgument("box_index has incompatible shape");
}
return Status::OK();
} }
// Verifies that all values in box_ind are in [0, batch). // Conditionally calls the compute callback if all values in box_index are in
// [0, batch_size) then calls done.
template <typename Device> template <typename Device>
inline void CheckValidBoxInd( inline void RunIfBoxIndexIsValid(
OpKernelContext* context, OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch); int batch_size, Callback compute, Callback done);
// Specialization of RunIfBoxIndexIsValid for a CPUDevice.
template <>
inline void RunIfBoxIndexIsValid<CPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
int batch_size, Callback compute, Callback done) {
const int num_boxes = box_index.dimension(0);
for (int b = 0; b < num_boxes; ++b) {
OP_REQUIRES_ASYNC(
context, FastBoundsCheck(box_index(b), batch_size),
errors::OutOfRange("box_index has values outside [0, batch_size)"),
done);
}
compute();
done();
}
} // namespace
template <typename Device, typename T> template <typename Device, typename T>
class CropAndResizeOp : public OpKernel { class CropAndResizeOp : public AsyncOpKernel {
public: public:
explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) { explicit CropAndResizeOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
string method; string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear", OP_REQUIRES(context, method == "bilinear",
@ -80,69 +112,77 @@ class CropAndResizeOp : public OpKernel {
&extrapolation_value_)); &extrapolation_value_));
} }
void Compute(OpKernelContext* context) override { void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'image' is [batch, image_height, image_width, channels]. // The shape of 'image' is [batch_size, image_height, image_width,
// channels].
const Tensor& image = context->input(0); const Tensor& image = context->input(0);
OP_REQUIRES(context, image.dims() == 4,
errors::InvalidArgument("input image must be 4-D",
image.shape().DebugString()));
const int batch = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
const int depth = image.dim_size(3);
OP_REQUIRES(context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"));
// The shape of 'boxes' is [num_boxes, 4]. // The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1); const Tensor& boxes = context->input(1);
// The shape of 'box_index' is [num_boxes].
// The shape of 'box_ind' is [num_boxes]. const Tensor& box_index = context->input(2);
const Tensor& box_ind = context->input(2);
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
// The shape of 'crop_size' is [2]. // The shape of 'crop_size' is [2].
const Tensor& crop_size = context->input(3); const Tensor& crop_size = context->input(3);
OP_REQUIRES(context, crop_size.dims() == 1, // Validate inputs dimensions.
errors::InvalidArgument("crop_size must be 1-D", OP_REQUIRES_ASYNC(context, image.dims() == 4,
crop_size.shape().DebugString())); errors::InvalidArgument("input image must be 4-D",
OP_REQUIRES(context, crop_size.dim_size(0) == 2, image.shape().DebugString()),
errors::InvalidArgument("crop_size must have two elements", done);
crop_size.shape().DebugString())); const int batch_size = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
const int depth = image.dim_size(3);
OP_REQUIRES_ASYNC(
context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive"), done);
int num_boxes = 0;
OP_REQUIRES_OK_ASYNC(
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
errors::InvalidArgument("crop_size must be 1-D",
crop_size.shape().DebugString()),
done);
OP_REQUIRES_ASYNC(
context, crop_size.dim_size(0) == 2,
errors::InvalidArgument("crop_size must have two elements",
crop_size.shape().DebugString()),
done);
// Copy and validate crop sizes.
auto crop_size_vec = crop_size.vec<int32>(); auto crop_size_vec = crop_size.vec<int32>();
const int crop_height = internal::SubtleMustCopy(crop_size_vec(0)); const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
const int crop_width = internal::SubtleMustCopy(crop_size_vec(1)); const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
OP_REQUIRES(context, crop_height > 0 && crop_width > 0, OP_REQUIRES_ASYNC(
errors::InvalidArgument("crop dimensions must be positive")); context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("crop dimensions must be positive"), done);
// Allocate output tensor. // Allocate output tensor.
Tensor* output = nullptr; Tensor* output = nullptr;
OP_REQUIRES_OK( OP_REQUIRES_OK_ASYNC(
context, context,
context->allocate_output( context->allocate_output(
0, TensorShape({num_boxes, crop_height, crop_width, depth}), 0, TensorShape({num_boxes, crop_height, crop_width, depth}),
&output)); &output),
done);
typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>(); auto compute_callback = [this, context, output]() {
typename TTypes<float, 2>::ConstTensor boxes_data = const Tensor& image = context->input(0);
boxes.tensor<float, 2>(); const Tensor& boxes = context->input(1);
typename TTypes<int32, 1>::ConstTensor box_ind_data = const Tensor& box_index = context->input(2);
box_ind.tensor<int32, 1>(); const bool status = functor::CropAndResize<Device, T>()(
typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>(); context->eigen_device<Device>(), image.tensor<T, 4>(),
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
extrapolation_value_, output->tensor<float, 4>());
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
};
CheckValidBoxInd<Device>(context, box_ind_data, batch); RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
batch_size, std::move(compute_callback),
bool status = functor::CropAndResize<Device, T>()( std::move(done));
context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
extrapolation_value_, crops_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeKernel."));
}
} }
private: private:
@ -155,10 +195,10 @@ template <typename T>
struct CropAndResize<CPUDevice, T> { struct CropAndResize<CPUDevice, T> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image, bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes, typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind, typename TTypes<int32, 1>::ConstTensor box_index,
float extrapolation_value, float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) { typename TTypes<float, 4>::Tensor crops) {
const int batch = image.dimension(0); const int batch_size = image.dimension(0);
const int image_height = image.dimension(1); const int image_height = image.dimension(1);
const int image_width = image.dimension(2); const int image_width = image.dimension(2);
@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> {
const float y2 = boxes(b, 2); const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3); const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b); const int32 b_in = box_index(b);
if (b_in < 0 || b_in >= batch) { if (!FastBoundsCheck(b_in, batch_size)) {
continue; continue;
} }
@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> {
return true; return true;
} }
}; };
} // namespace functor } // namespace functor
template <typename Device, typename T> template <typename Device, typename T>
class CropAndResizeGradImageOp : public OpKernel { class CropAndResizeGradImageOp : public AsyncOpKernel {
public: public:
explicit CropAndResizeGradImageOp(OpKernelConstruction* context) explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
: OpKernel(context) { : AsyncOpKernel(context) {
string method; string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear", OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method)); errors::InvalidArgument("method must be 'bilinear'", method));
} }
void Compute(OpKernelContext* context) override { void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0); const Tensor& grads = context->input(0);
OP_REQUIRES(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()));
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("grads dimensions must be positive"));
// The shape of 'boxes' is [num_boxes, 4]. // The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1); const Tensor& boxes = context->input(1);
// The shape of 'box_index' is [num_boxes].
// The shape of 'box_ind' is [num_boxes]. const Tensor& box_index = context->input(2);
const Tensor& box_ind = context->input(2);
int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
OP_REQUIRES(
context, grads.dim_size(0) == num_boxes,
errors::InvalidArgument("boxes and grads have incompatible shape"));
// The shape of 'image_size' is [4]. // The shape of 'image_size' is [4].
const Tensor& image_size = context->input(3); const Tensor& image_size = context->input(3);
OP_REQUIRES(context, image_size.dims() == 1,
errors::InvalidArgument("image_size must be 1-D",
image_size.shape().DebugString()));
OP_REQUIRES(context, image_size.dim_size(0) == 4,
errors::InvalidArgument("image_size must have 4 elements",
image_size.shape().DebugString()));
// Validate input shapes.
OP_REQUIRES_ASYNC(context, grads.dims() == 4,
errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()),
done);
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
OP_REQUIRES_ASYNC(
context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("grads dimensions must be positive"), done);
int num_boxes = 0;
OP_REQUIRES_OK_ASYNC(
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
OP_REQUIRES_ASYNC(
context, grads.dim_size(0) == num_boxes,
errors::InvalidArgument("boxes and grads have incompatible shape"),
done);
OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
errors::InvalidArgument("image_size must be 1-D",
image_size.shape().DebugString()),
done);
OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
errors::InvalidArgument("image_size must have 4 elements",
image_size.shape().DebugString()),
done);
auto image_size_vec = image_size.vec<int32>(); auto image_size_vec = image_size.vec<int32>();
const int batch = internal::SubtleMustCopy(image_size_vec(0)); const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
const int image_height = internal::SubtleMustCopy(image_size_vec(1)); const int image_height = internal::SubtleMustCopy(image_size_vec(1));
const int image_width = internal::SubtleMustCopy(image_size_vec(2)); const int image_width = internal::SubtleMustCopy(image_size_vec(2));
const int depth = internal::SubtleMustCopy(image_size_vec(3)); const int depth = internal::SubtleMustCopy(image_size_vec(3));
OP_REQUIRES_ASYNC(
OP_REQUIRES(context, image_height > 0 && image_width > 0, context, image_height > 0 && image_width > 0,
errors::InvalidArgument("image dimensions must be positive")); errors::InvalidArgument("image dimensions must be positive"), done);
OP_REQUIRES( OP_REQUIRES_ASYNC(
context, grads.dim_size(3) == depth, context, grads.dim_size(3) == depth,
errors::InvalidArgument("image_size and grads are incompatible")); errors::InvalidArgument("image_size and grads are incompatible"), done);
// Allocate output tensor. // Allocate output tensor.
Tensor* output = nullptr; Tensor* output = nullptr;
OP_REQUIRES_OK( OP_REQUIRES_OK_ASYNC(
context, context->allocate_output( context,
0, TensorShape({batch, image_height, image_width, depth}), context->allocate_output(
&output)); 0, TensorShape({batch_size, image_height, image_width, depth}),
&output),
done);
typename TTypes<float, 4>::ConstTensor grads_data = auto compute_callback = [context, output]() {
grads.tensor<float, 4>(); const Tensor& grads = context->input(0);
typename TTypes<float, 2>::ConstTensor boxes_data = const Tensor& boxes = context->input(1);
boxes.tensor<float, 2>(); const Tensor& box_index = context->input(2);
typename TTypes<int32, 1>::ConstTensor box_ind_data = const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
box_ind.tensor<int32, 1>(); context->eigen_device<Device>(), grads.tensor<float, 4>(),
typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>(); boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
output->tensor<T, 4>());
if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropImage kernel."));
}
};
CheckValidBoxInd<Device>(context, box_ind_data, batch); RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
batch_size, std::move(compute_callback),
bool status = functor::CropAndResizeBackpropImage<Device, T>()( std::move(done));
context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
}
} }
}; };
@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
bool operator()(const CPUDevice& d, bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes, typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind, typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<T, 4>::Tensor grads_image) { typename TTypes<T, 4>::Tensor grads_image) {
const int batch = grads_image.dimension(0); const int batch_size = grads_image.dimension(0);
const int image_height = grads_image.dimension(1); const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2); const int image_width = grads_image.dimension(2);
@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float y2 = boxes(b, 2); const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3); const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b); const int32 b_in = box_index(b);
if (b_in < 0 || b_in >= batch) { if (!FastBoundsCheck(b_in, batch_size)) {
continue; continue;
} }
@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
return true; return true;
} }
}; };
} // namespace functor } // namespace functor
template <typename Device, typename T> template <typename Device, typename T>
class CropAndResizeGradBoxesOp : public OpKernel { class CropAndResizeGradBoxesOp : public AsyncOpKernel {
public: public:
explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context) explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
: OpKernel(context) { : AsyncOpKernel(context) {
string method; string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear", OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method)); errors::InvalidArgument("method must be 'bilinear'", method));
} }
void Compute(OpKernelContext* context) override { void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0); const Tensor& grads = context->input(0);
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(2);
// The shape of 'box_index' is [num_boxes].
const Tensor& box_index = context->input(3);
// The shape of 'image' is [batch_size, image_height, image_width, depth].
const Tensor& image = context->input(1);
OP_REQUIRES(context, grads.dims() == 4, // Validate input shapes.
errors::InvalidArgument("grads image must be 4-D", OP_REQUIRES_ASYNC(context, grads.dims() == 4,
grads.shape().DebugString())); errors::InvalidArgument("grads image must be 4-D",
grads.shape().DebugString()),
done);
const int crop_height = grads.dim_size(1); const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2); const int crop_width = grads.dim_size(2);
const int depth = grads.dim_size(3); const int depth = grads.dim_size(3);
OP_REQUIRES(context, crop_height > 0 && crop_width > 0, OP_REQUIRES_ASYNC(
errors::InvalidArgument("grads dimensions must be positive")); context, crop_height > 0 && crop_width > 0,
errors::InvalidArgument("grads dimensions must be positive"), done);
// The shape of 'image' is [batch, image_height, image_width, depth]. OP_REQUIRES_ASYNC(context, image.dims() == 4,
const Tensor& image = context->input(1); errors::InvalidArgument("input image must be 4-D",
OP_REQUIRES(context, image.dims() == 4, image.shape().DebugString()),
errors::InvalidArgument("input image must be 4-D", done);
image.shape().DebugString())); const int batch_size = image.dim_size(0);
const int batch = image.dim_size(0);
const int image_height = image.dim_size(1); const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2); const int image_width = image.dim_size(2);
OP_REQUIRES(context, image_height > 0 && image_width > 0, OP_REQUIRES_ASYNC(
errors::InvalidArgument("image dimensions must be positive")); context, image_height > 0 && image_width > 0,
OP_REQUIRES(context, image.dim_size(3) == depth, errors::InvalidArgument("image dimensions must be positive"), done);
errors::InvalidArgument("image, grads depth differ")); OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
errors::InvalidArgument("image, grads depth differ"),
// The shape of 'boxes' is [num_boxes, 4]. done);
const Tensor& boxes = context->input(2);
// The shape of 'box_ind' is [num_boxes].
const Tensor& box_ind = context->input(3);
int num_boxes = 0; int num_boxes = 0;
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); OP_REQUIRES_OK_ASYNC(
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
OP_REQUIRES( OP_REQUIRES_ASYNC(
context, grads.dim_size(0) == num_boxes, context, grads.dim_size(0) == num_boxes,
errors::InvalidArgument("boxes and grads have incompatible shape")); errors::InvalidArgument("boxes and grads have incompatible shape"),
done);
// Allocate output tensor. // Allocate output tensor.
Tensor* output = nullptr; Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output( OP_REQUIRES_OK_ASYNC(
0, TensorShape({num_boxes, 4}), &output)); context,
context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
done);
typename TTypes<float, 4>::ConstTensor grads_data = auto compute_callback = [context, output]() {
grads.tensor<float, 4>(); const Tensor& grads = context->input(0);
typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>(); const Tensor& image = context->input(1);
typename TTypes<float, 2>::ConstTensor boxes_data = const Tensor& boxes = context->input(2);
boxes.tensor<float, 2>(); const Tensor& box_index = context->input(3);
typename TTypes<int32, 1>::ConstTensor box_ind_data = const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
box_ind.tensor<int32, 1>(); context->eigen_device<Device>(), grads.tensor<float, 4>(),
typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>(); image.tensor<T, 4>(), boxes.tensor<float, 2>(),
box_index.tensor<int32, 1>(), output->tensor<float, 2>());
if (!status) {
context->SetStatus(errors::Internal(
"Failed launch CropAndResizeBackpropBoxes kernel."));
}
};
CheckValidBoxInd<Device>(context, box_ind_data, batch); RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
batch_size, std::move(compute_callback),
bool status = functor::CropAndResizeBackpropBoxes<Device, T>()( std::move(done));
context->eigen_device<Device>(), grads_data, image_data, boxes_data,
box_ind_data, output_data);
if (!status) {
context->SetStatus(
errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
}
} }
}; };
@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads, typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes, typename TTypes<float, 2>::ConstTensor boxes,
typename TTypes<int32, 1>::ConstTensor box_ind, typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<float, 2>::Tensor grads_boxes) { typename TTypes<float, 2>::Tensor grads_boxes) {
const int batch = image.dimension(0); const int batch_size = image.dimension(0);
const int image_height = image.dimension(1); const int image_height = image.dimension(1);
const int image_width = image.dimension(2); const int image_width = image.dimension(2);
@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float y2 = boxes(b, 2); const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3); const float x2 = boxes(b, 3);
const int32 b_in = box_ind(b); const int32 b_in = box_index(b);
if (b_in < 0 || b_in >= batch) { if (!FastBoundsCheck(b_in, batch_size)) {
continue; continue;
} }
@ -589,30 +641,19 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
return true; return true;
} }
}; };
} // namespace functor } // namespace functor
// Specialization of CheckValidBoxInd for a CPUDevice. #define REGISTER_KERNEL(T) \
template <> REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
inline void CheckValidBoxInd<CPUDevice>( .Device(DEVICE_CPU) \
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind, .TypeConstraint<T>("T") \
int batch) { .HostMemory("crop_size"), \
const int num_boxes = box_ind.dimension(0); CropAndResizeOp<CPUDevice, T>); \
for (int b = 0; b < num_boxes; ++b) { \
OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch, REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
errors::OutOfRange("box_ind has values outside [0, batch)")); .Device(DEVICE_CPU) \
} .TypeConstraint<T>("T"), \
}
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.HostMemory("crop_size"), \
CropAndResizeOp<CPUDevice, T>); \
\
REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>); CropAndResizeGradBoxesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL);
#if GOOGLE_CUDA #if GOOGLE_CUDA
// Forward declaration of the CheckValidBoxIndHelper specialization for GPU. // Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
namespace functor { namespace functor {
template <> template <>
void CheckValidBoxIndHelper<GPUDevice>::operator()( void CheckValidBoxIndexHelper<GPUDevice>::operator()(
const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind, const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
int batch, typename TTypes<bool, 0>::Tensor isvalid); int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
extern template struct CheckValidBoxIndHelper<GPUDevice>; extern template struct CheckValidBoxIndexHelper<GPUDevice>;
} // namespace functor } // namespace functor
// Specialization of CheckValidBoxInd for a GPUDevice. namespace {
// Specialization of RunIfBoxIndexIsValid for a GPUDevice.
template <> template <>
inline void CheckValidBoxInd<GPUDevice>( inline void RunIfBoxIndexIsValid<GPUDevice>(
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind, OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
int batch) { int batch_size, Callback compute, Callback done) {
const int num_boxes = box_ind.dimension(0); const int num_boxes = box_index.dimension(0);
if (num_boxes == 0) { if (num_boxes == 0) {
compute();
done();
return; return;
} }
Tensor isvalid_tensor;
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<bool>::value,
TensorShape({}), &isvalid_tensor));
typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>(); Tensor isvalid_dev_tensor;
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
&isvalid_dev_tensor),
done);
typename TTypes<bool, 0>::Tensor isvalid_dev =
isvalid_dev_tensor.tensor<bool, 0>();
functor::CheckValidBoxIndHelper<GPUDevice>()( // Run the actual box check on the device.
context->eigen_device<GPUDevice>(), box_ind, batch, isvalid); functor::CheckValidBoxIndexHelper<GPUDevice>()(
context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
// Copy the result back to the host.
auto* stream = context->op_device_context()->stream(); auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); OP_REQUIRES_ASYNC(context, stream,
errors::Internal("No GPU stream available."), done);
Tensor isvalid_host_tensor;
// Use pinned host memory on the host to avoid unnecessary
// synchronization.
AllocatorAttributes alloc_attr;
alloc_attr.set_on_host(true);
alloc_attr.set_gpu_compatible(true);
OP_REQUIRES_OK_ASYNC(
context,
context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
&isvalid_host_tensor, alloc_attr),
done);
perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
sizeof(bool));
const bool status =
stream
->ThenMemcpy(
isvalid_host_tensor.scalar<bool>().data() /* destination */,
wrapped /* source */, sizeof(bool))
.ok();
OP_REQUIRES_ASYNC(
context, status,
errors::Internal("Failed to launch copy of isvalid from device to host."),
done);
bool isvalid_host = false; auto wrapped_callback = [context, isvalid_host_tensor, compute, done]() {
perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(), const bool isvalid = isvalid_host_tensor.scalar<bool>()();
sizeof(bool)); OP_REQUIRES_ASYNC(
stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool)); context, isvalid,
stream->BlockHostUntilDone(); errors::OutOfRange("box_index has values outside [0, batch_size)"),
done);
compute();
done();
};
OP_REQUIRES(context, stream->ok(), context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
errors::Internal("cudaMemcpy from device to host failed")); stream, wrapped_callback);
OP_REQUIRES(context, isvalid_host,
errors::OutOfRange("box_ind has values outside [0, batch)"));
} }
} // namespace
#define REGISTER_KERNEL(T) \ #define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_GPU) \ .Device(DEVICE_GPU) \

View File

@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes {
}; };
template <typename Device> template <typename Device>
struct CheckValidBoxIndHelper { struct CheckValidBoxIndexHelper {
// Checks if all values in box_ind are in [0, batch). // Checks if all values in box_index are in [0, batch).
void operator()(const Device& d, void operator()(const Device& d,
typename TTypes<int32, 1>::ConstTensor box_ind, int batch, typename TTypes<int32, 1>::ConstTensor box_index, int batch,
typename TTypes<bool, 0>::Tensor isvalid) { typename TTypes<bool, 0>::Tensor isvalid) {
isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all(); isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all();
} }
}; };

View File

@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS #undef DEFINE_GPU_SPECS
template struct CheckValidBoxIndHelper<GPUDevice>; template struct CheckValidBoxIndexHelper<GPUDevice>;
} // namespace functor } // namespace functor
} // namespace tensorflow } // namespace tensorflow

View File

@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
Status s = RunOpKernel(); Status s = RunOpKernel();
ASSERT_FALSE(s.ok()); ASSERT_FALSE(s.ok());
EXPECT_TRUE( EXPECT_TRUE(
StringPiece(s.ToString()).contains("box_ind has incompatible shape")) StringPiece(s.ToString()).contains("box_index has incompatible shape"))
<< s; << s;
} }
@ -264,8 +264,10 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
Status s = RunOpKernel(); Status s = RunOpKernel();
ASSERT_FALSE(s.ok()); ASSERT_FALSE(s.ok());
EXPECT_TRUE(StringPiece(s.ToString()) EXPECT_TRUE(StringPiece(s.ToString())
.contains("box_ind has values outside [0, batch)")) .contains("box_index has values outside [0, batch_size)"))
<< s; << s;
} }
// TODO(zhengxq, rmlarsen): Add a benchmark.
} // namespace tensorflow } // namespace tensorflow

View File

@ -20,4 +20,4 @@ REGISTER2(BinaryOp, CPU, "Atan2", functor::atan2, float, double);
#if GOOGLE_CUDA #if GOOGLE_CUDA
REGISTER2(BinaryOp, GPU, "Atan2", functor::atan2, float, double); REGISTER2(BinaryOp, GPU, "Atan2", functor::atan2, float, double);
#endif #endif
} // namespace tensorflow } // namespace tensorflow

View File

@ -23,4 +23,4 @@ DEFINE_BINARY2(atan2, float, double);
} // namespace functor } // namespace functor
} // namespace tensorflow } // namespace tensorflow
#endif // GOOGLE_CUDA #endif // GOOGLE_CUDA

View File

@ -155,7 +155,8 @@ void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
const int col_dimension = input_rank - 1; const int col_dimension = input_rank - 1;
const int64 num_rows = in.dim_size(row_dimension); const int64 num_rows = in.dim_size(row_dimension);
const int64 num_cols = in.dim_size(col_dimension); const int64 num_cols = in.dim_size(col_dimension);
input_matrix_shapes->emplace_back(std::initializer_list<int64>({num_rows, num_cols})); input_matrix_shapes->emplace_back(
std::initializer_list<int64>({num_rows, num_cols}));
inputs->emplace_back(&in); inputs->emplace_back(&in);
} }
// Have the derived class validate that the inputs are as expected. // Have the derived class validate that the inputs are as expected.
@ -233,8 +234,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
matrix_inputs.emplace_back( matrix_inputs.emplace_back(
inputs[i]->flat<Scalar>().data() + inputs[i]->flat<Scalar>().data() +
matrix_index * input_matrix_shapes[i].num_elements(), matrix_index * input_matrix_shapes[i].num_elements(),
input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
input_matrix_shapes[i].dim_size(1));
} }
MatrixMaps matrix_outputs; MatrixMaps matrix_outputs;

View File

@ -1716,6 +1716,31 @@ op {
} }
} }
} }
op {
name: "Atan2"
input_arg {
name: "y"
type_attr: "T"
}
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "z"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
op { op {
name: "AudioSpectrogram" name: "AudioSpectrogram"
input_arg { input_arg {

View File

@ -1904,6 +1904,33 @@ op {
} }
summary: "Computes atan of x element-wise." summary: "Computes atan of x element-wise."
} }
op {
name: "Atan2"
input_arg {
name: "y"
type_attr: "T"
}
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "z"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Computes arctangent of `y/x` element-wise, respecting signs of the arguments."
description: "This is the angle \\( \\theta \\in [-\\pi, \\pi] \\) such that\n\\[ x = r \\cos(\\theta) \\]\nand\n\\[ y = r \\sin(\\theta) \\]\nwhere \\(r = \\sqrt{x^2 + y^2} \\)."
}
op { op {
name: "AudioSpectrogram" name: "AudioSpectrogram"
input_arg { input_arg {

View File

@ -1,33 +0,0 @@
# Random variable transformations (contrib)
[TOC]
Bijector Ops.
An API for invertible, differentiable transformations of random variables.
## Background
Differentiable, bijective transformations of continuous random variables alter
the calculations made in the cumulative/probability distribution functions and
sample function. This module provides a standard interface for making these
manipulations.
For more details and examples, see the `Bijector` docstring.
To apply a `Bijector`, use `distributions.TransformedDistribution`.
## Bijectors
* @{tf.contrib.distributions.bijector.Affine}
* @{tf.contrib.distributions.bijector.AffineLinearOperator}
* @{tf.contrib.distributions.bijector.Bijector}
* @{tf.contrib.distributions.bijector.Chain}
* @{tf.contrib.distributions.bijector.CholeskyOuterProduct}
* @{tf.contrib.distributions.bijector.Exp}
* @{tf.contrib.distributions.bijector.Identity}
* @{tf.contrib.distributions.bijector.Inline}
* @{tf.contrib.distributions.bijector.Invert}
* @{tf.contrib.distributions.bijector.PowerTransform}
* @{tf.contrib.distributions.bijector.SigmoidCentered}
* @{tf.contrib.distributions.bijector.SoftmaxCentered}
* @{tf.contrib.distributions.bijector.Softplus}

View File

@ -0,0 +1,33 @@
# Random variable transformations (contrib)
[TOC]
Bijector Ops.
An API for invertible, differentiable transformations of random variables.
## Background
Differentiable, bijective transformations of continuous random variables alter
the calculations made in the cumulative/probability distribution functions and
sample function. This module provides a standard interface for making these
manipulations.
For more details and examples, see the `Bijector` docstring.
To apply a `Bijector`, use `distributions.TransformedDistribution`.
## Bijectors
* @{tf.contrib.distributions.bijectors.Affine}
* @{tf.contrib.distributions.bijectors.AffineLinearOperator}
* @{tf.contrib.distributions.bijectors.Bijector}
* @{tf.contrib.distributions.bijectors.Chain}
* @{tf.contrib.distributions.bijectors.CholeskyOuterProduct}
* @{tf.contrib.distributions.bijectors.Exp}
* @{tf.contrib.distributions.bijectors.Identity}
* @{tf.contrib.distributions.bijectors.Inline}
* @{tf.contrib.distributions.bijectors.Invert}
* @{tf.contrib.distributions.bijectors.PowerTransform}
* @{tf.contrib.distributions.bijectors.SigmoidCentered}
* @{tf.contrib.distributions.bijectors.SoftmaxCentered}
* @{tf.contrib.distributions.bijectors.Softplus}

View File

@ -76,7 +76,7 @@ representing the posterior or posterior predictive.
## Kullback-Leibler Divergence ## Kullback-Leibler Divergence
* @{tf.contrib.distributions.kl} * @{tf.contrib.distributions.kl_divergence}
* @{tf.contrib.distributions.RegisterKL} * @{tf.contrib.distributions.RegisterKL}
## Utilities ## Utilities

View File

@ -40,7 +40,7 @@
* [Losses (contrib)](contrib.losses.md) * [Losses (contrib)](contrib.losses.md)
* [Metrics (contrib)](contrib.metrics.md) * [Metrics (contrib)](contrib.metrics.md)
* [Optimization (contrib)](contrib.opt.md) * [Optimization (contrib)](contrib.opt.md)
* [Random variable transformations (contrib)](contrib.distributions.bijector.md) * [Random variable transformations (contrib)](contrib.distributions.bijectors.md)
* [RNN and Cells (contrib)](contrib.rnn.md) * [RNN and Cells (contrib)](contrib.rnn.md)
* [Seq2seq Library (contrib)](contrib.seq2seq.md) * [Seq2seq Library (contrib)](contrib.seq2seq.md)
* [Staging (contrib)](contrib.staging.md) * [Staging (contrib)](contrib.staging.md)

View File

@ -80,10 +80,12 @@ section.
* **OS:** Ubuntu 16.04 LTS with tests run via Docker * **OS:** Ubuntu 16.04 LTS with tests run via Docker
* **CUDA / cuDNN:** 8.0 / 5.1 * **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e * **TensorFlow GitHub hash:** b1e174e
* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda * **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package` //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** Local SSD * **Disk:** Local SSD
* **DataSet:** ImageNet * **DataSet:** ImageNet
* **Test Date:** May 2017
Batch size and optimizer used for each model are listed in the table below. In Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3, ResNet-50, addition to the batch sizes listed in the table, InceptionV3, ResNet-50,
@ -120,19 +122,19 @@ VGG16 | replicated (with NCCL) | n/a
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 238 | 95.6 | 2987 | 154 1 | 142 | 219 | 91.8 | 2987 | 154
2 | 284 | 479 | 187 | 5658 | 295 2 | 284 | 422 | 181 | 5658 | 295
4 | 569 | 948 | 374 | 10509 | 584 4 | 569 | 852 | 356 | 10509 | 584
8 | 1131 | 1886 | 744 | 17822 | 1081 8 | 1131 | 1734 | 716 | 17822 | 1081
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 239 | 95.5 | 2890 | 154 1 | 142 | 218 | 91.4 | 2890 | 154
2 | 278 | 468 | 187 | 4448 | 284 2 | 278 | 425 | 179 | 4448 | 284
4 | 551 | 938 | 373 | 7105 | 534 4 | 551 | 853 | 359 | 7105 | 534
8 | 1079 | 1802 | 721 | N/A | 898 8 | 1079 | 1630 | 708 | N/A | 898
Training AlexNet with real data on 8 GPUs was excluded from the graph and table Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to it maxing out the input pipeline. above due to it maxing out the input pipeline.
@ -145,19 +147,19 @@ The results below are all with a batch size of 32.
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | ----- ---- | ----------- | --------- | ---------- | -----
1 | 128 | 210 | 85.3 | 144 1 | 128 | 195 | 82.7 | 144
2 | 259 | 412 | 166 | 281 2 | 259 | 368 | 160 | 281
4 | 520 | 827 | 330 | 549 4 | 520 | 768 | 317 | 549
8 | 995 | 1623 | 643 | 820 8 | 995 | 1485 | 632 | 820
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | ----- ---- | ----------- | --------- | ---------- | -----
1 | 130 | 208 | 85.0 | 144 1 | 130 | 193 | 82.4 | 144
2 | 257 | 403 | 163 | 253 2 | 257 | 369 | 159 | 253
4 | 507 | 814 | 325 | 457 4 | 507 | 760 | 317 | 457
8 | 966 | 1525 | 641 | 690 8 | 966 | 1410 | 609 | 690
## Details for Google Compute Engine (NVIDIA® Tesla® K80) ## Details for Google Compute Engine (NVIDIA® Tesla® K80)
@ -168,11 +170,12 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
* **OS:** Ubuntu 16.04 LTS * **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1 * **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e * **TensorFlow GitHub hash:** b1e174e
* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda * **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package` //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s) * **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s)
* **DataSet:** ImageNet * **DataSet:** ImageNet
* **Test Date:** April 2017 * **Test Date:** May 2017
Batch size and optimizer used for each model are listed in the table below. In Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
@ -198,19 +201,19 @@ The configuration used for each model was `variable_update` equal to
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.8 | 20.8 | 656 | 35.4 1 | 30.5 | 51.9 | 20.0 | 656 | 35.4
2 | 57.8 | 107 | 39.1 | 1209 | 64.8 2 | 57.8 | 99.0 | 38.2 | 1209 | 64.8
4 | 116 | 212 | 77.2 | 2328 | 120 4 | 116 | 195 | 75.8 | 2328 | 120
8 | 227 | 419 | 151 | 4640 | 234 8 | 227 | 387 | 148 | 4640 | 234
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.6 | 56.7 | 20.7 | 639 | 34.2 1 | 30.6 | 51.2 | 20.0 | 639 | 34.2
2 | 58.4 | 107 | 39.0 | 1136 | 62.9 2 | 58.4 | 98.8 | 38.3 | 1136 | 62.9
4 | 115 | 211 | 77.3 | 2067 | 118 4 | 115 | 194 | 75.4 | 2067 | 118
8 | 225 | 422 | 151 | 4056 | 230 8 | 225 | 381 | 148 | 4056 | 230
### Other Results ### Other Results
@ -218,19 +221,19 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | ------------------------- ---- | --------------------------- | -------------------------
1 | 29.3 | 53.9 1 | 29.3 | 49.5
2 | 55.0 | 101 2 | 55.0 | 95.4
4 | 109 | 200 4 | 109 | 183
8 | 216 | 398 8 | 216 | 362
**Training real data** **Training real data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | ------------------------- ---- | --------------------------- | -------------------------
1 | 29.5 | 53.6 1 | 29.5 | 49.3
2 | 55.4 | 102 2 | 55.4 | 95.3
4 | 110 | 201 4 | 110 | 186
8 | 216 | 387 8 | 216 | 359
## Details for Amazon EC2 (NVIDIA® Tesla® K80) ## Details for Amazon EC2 (NVIDIA® Tesla® K80)
@ -241,12 +244,13 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
* **OS:** Ubuntu 16.04 LTS * **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1 * **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e * **TensorFlow GitHub hash:** b1e174e
* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda * **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package` //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50 * **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50
MiB/sec) MiB/sec)
* **DataSet:** ImageNet * **DataSet:** ImageNet
* **Test Date:** April 2017 * **Test Date:** May 2017
Batch size and optimizer used for each model are listed in the table below. In Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
@ -279,19 +283,19 @@ VGG16 | parameter_server | gpu
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.8 | 56.3 | 20.9 | 684 | 36.3 1 | 30.8 | 51.5 | 19.7 | 684 | 36.3
2 | 58.7 | 108 | 39.3 | 1244 | 69.4 2 | 58.7 | 98.0 | 37.6 | 1244 | 69.4
4 | 117 | 217 | 79.1 | 2479 | 141 4 | 117 | 195 | 74.9 | 2479 | 141
8 | 230 | 419 | 156 | 4853 | 260 8 | 230 | 384 | 149 | 4853 | 260
**Training real data** **Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16 GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | ----- ---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.0 | 20.6 | 674 | 36.3 1 | 30.5 | 51.3 | 19.7 | 674 | 36.3
2 | 59.0 | 107 | 39.0 | 1227 | 67.5 2 | 59.0 | 94.9 | 38.2 | 1227 | 67.5
4 | 118 | 205 | 77.9 | 2201 | 136 4 | 118 | 188 | 75.2 | 2201 | 136
8 | 228 | 405 | 152 | N/A | 242 8 | 228 | 373 | 149 | N/A | 242
Training AlexNet with real data on 8 GPUs was excluded from the graph and table Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to our EFS setup not providing enough throughput. above due to our EFS setup not providing enough throughput.
@ -302,19 +306,19 @@ above due to our EFS setup not providing enough throughput.
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | ------------------------- ---- | --------------------------- | -------------------------
1 | 29.9 | 53.5 1 | 29.9 | 49.0
2 | 57.5 | 101 2 | 57.5 | 94.1
4 | 114 | 202 4 | 114 | 184
8 | 216 | 380 8 | 216 | 355
**Training real data** **Training real data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | ------------------------- ---- | --------------------------- | -------------------------
1 | 30.0 | 53.6 1 | 30.0 | 49.1
2 | 57.5 | 102 2 | 57.5 | 95.1
4 | 113 | 202 4 | 113 | 185
8 | 212 | 379 8 | 212 | 353
## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80) ## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80)
@ -325,11 +329,12 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
* **OS:** Ubuntu 16.04 LTS * **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1 * **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e * **TensorFlow GitHub hash:** b1e174e
* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda * **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package` //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec) * **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec)
* **DataSet:** ImageNet * **DataSet:** ImageNet
* **Test Date:** April 2017 * **Test Date:** May 2017
The batch size and optimizer used for the tests are listed in the table. In The batch size and optimizer used for the tests are listed in the table. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
@ -343,11 +348,11 @@ Optimizer | sgd | sgd | sgd
Configuration used for each model. Configuration used for each model.
Model | variable_update | local_parameter_device Model | variable_update | local_parameter_device | cross_replica_sync
----------- | ---------------------- | ---------------------- ----------- | ---------------------- | ---------------------- | ------------------
InceptionV3 | distributed_replicated | n/a InceptionV3 | distributed_replicated | n/a | True
ResNet-50 | distributed_replicated | n/a ResNet-50 | distributed_replicated | n/a | True
ResNet-152 | distributed_replicated | n/a ResNet-152 | distributed_replicated | n/a | True
To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also
ran parameter servers. Equal numbers of parameter servers and worker servers were ran parameter servers. Equal numbers of parameter servers and worker servers were
@ -371,11 +376,11 @@ used with the following exceptions:
GPUs | InceptionV3 | ResNet-50 | ResNet-152 GPUs | InceptionV3 | ResNet-50 | ResNet-152
---- | ----------- | --------- | ---------- ---- | ----------- | --------- | ----------
1 | 29.7 | 55.0 | 19.8 1 | 29.7 | 52.4 | 19.4
8 | 229 | 410 | 150 8 | 229 | 378 | 146
16 | 459 | 825 | 300 16 | 459 | 751 | 291
32 | 902 | 1468 | 575 32 | 902 | 1388 | 565
64 | 1783 | 3051 | 1004 64 | 1783 | 2744 | 981
### Other Results ### Other Results
@ -387,23 +392,23 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | ------------------------- ---- | --------------------------- | -------------------------
1 | 29.2 | 53.0 1 | 29.2 | 48.4
8 | 219 | 363 8 | 219 | 333
16 | 427 | 719 16 | 427 | 667
32 | 820 | 1265 32 | 820 | 1180
64 | 1608 | 2623 64 | 1608 | 2315
## Methodology ## Methodology
This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks) This
[script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was run on the various platforms to generate the above results. was run on the various platforms to generate the above results.
@{$performance_models$High-Performance Models} details techniques in the script @{$performance_models$High-Performance Models} details techniques in the script
along with examples of how to execute the script. along with examples of how to execute the script.
In order to create results that are as repeatable as possible, each test was run In order to create results that are as repeatable as possible, each test was run
5 times and then the times were averaged together. GPUs are run in their default 5 times and then the times were averaged together. GPUs are run in their default
state on the given platform. For NVIDIA® Tesla® K80 this means leaving on [GPU state on the given platform. For NVIDIA® Tesla® K80 this means leaving on [GPU
Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/). Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/).
For each test, 10 warmup steps are done and then the next 100 steps are For each test, 10 warmup steps are done and then the next 100 steps are
averaged. averaged.

View File

@ -370,9 +370,8 @@ def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
tf.logging.fatal('File does not exist %s', image_path) tf.logging.fatal('File does not exist %s', image_path)
image_data = gfile.FastGFile(image_path, 'rb').read() image_data = gfile.FastGFile(image_path, 'rb').read()
try: try:
bottleneck_values = run_bottleneck_on_image(sess, image_data, bottleneck_values = run_bottleneck_on_image(
jpeg_data_tensor, sess, image_data, jpeg_data_tensor, bottleneck_tensor)
bottleneck_tensor)
except: except:
raise RuntimeError('Error during processing file %s' % image_path) raise RuntimeError('Error during processing file %s' % image_path)

View File

@ -5583,6 +5583,74 @@ func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
return op.Output(0) return op.Output(0)
} }
// Stores the input tensor in the state of the current session.
//
// Arguments:
//	value: The tensor to be stored.
//
// Returns the handle for the tensor stored in the session state, represented
// as a ResourceHandle object.
func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
	// A scope that already carries an error adds nothing to the graph; the
	// zero-valued named result is returned as-is.
	if scope.Err() != nil {
		return
	}
	spec := tf.OpSpec{
		Type:  "GetSessionHandleV2",
		Input: []tf.Input{value},
	}
	handle = scope.AddOperation(spec).Output(0)
	return
}
// Adjusts the hue of one or more images.
//
// `images` is a tensor of at least 3 dimensions. The last dimension is
// interpreted as channels, and must be three.
//
// The input image is considered in the RGB colorspace. Conceptually, the RGB
// colors are first mapped into HSV. A delta is then applied to all the hue
// values, and the result is remapped back to the RGB colorspace.
//
// Arguments:
//	images: Images to adjust. At least 3-D.
//	delta: A float delta to add to the hue.
//
// Returns the hue-adjusted image or images.
func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
	// A scope that already carries an error adds nothing to the graph; the
	// zero-valued named result is returned as-is.
	if scope.Err() != nil {
		return
	}
	spec := tf.OpSpec{
		Type:  "AdjustHue",
		Input: []tf.Input{images, delta},
	}
	output = scope.AddOperation(spec).Output(0)
	return
}
// Restores a Reader to its initial clean state.
//
// Arguments:
//	reader_handle: Handle to a Reader.
//
// Returns the created operation.
func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
	// A scope that already carries an error adds nothing to the graph; the
	// nil named result is returned as-is.
	if scope.Err() != nil {
		return
	}
	o = scope.AddOperation(tf.OpSpec{
		Type:  "ReaderResetV2",
		Input: []tf.Input{reader_handle},
	})
	return
}
// Computes softmax cross entropy cost and gradients to backpropagate. // Computes softmax cross entropy cost and gradients to backpropagate.
// //
// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept // Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
@ -19039,6 +19107,27 @@ func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
return op.Output(0) return op.Output(0)
} }
// Computes the arctangent of `y/x` element-wise, respecting the signs of both
// arguments.
//
// The result is the angle \( \theta \in [-\pi, \pi] \) such that
// \[ x = r \cos(\theta) \]
// and
// \[ y = r \sin(\theta) \]
// where \(r = \sqrt{x^2 + y^2} \).
func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) {
	// A scope that already carries an error adds nothing to the graph; the
	// zero-valued named result is returned as-is.
	if scope.Err() != nil {
		return
	}
	spec := tf.OpSpec{
		Type:  "Atan2",
		Input: []tf.Input{y, x}, // operand order is (y, x), as declared by the op
	}
	z = scope.AddOperation(spec).Output(0)
	return
}
// Compute the regularized incomplete beta integral \\(I_x(a, b)\\). // Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
// //
// The regularized incomplete beta integral is defined as: // The regularized incomplete beta integral is defined as:
@ -21627,71 +21716,3 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O
op := scope.AddOperation(opspec) op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1) return op.Output(0), op.Output(1)
} }
// Store the input tensor in the state of the current session.
//
// Arguments:
//	value: The tensor to be stored.
//
// Returns the handle for the tensor stored in the session state, represented
// as a ResourceHandle object.
func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
	// A scope that already carries an error yields the zero-valued handle.
	if err := scope.Err(); err != nil {
		return
	}
	// The op has a single input: the tensor to be stored.
	inputs := []tf.Input{value}
	op := scope.AddOperation(tf.OpSpec{
		Type:  "GetSessionHandleV2",
		Input: inputs,
	})
	handle = op.Output(0)
	return
}
// Adjust the hue of one or more images.
//
// `images` is a tensor of at least 3 dimensions. The last dimension is
// interpreted as channels, and must be three.
//
// The input image is considered in the RGB colorspace. Conceptually, the RGB
// colors are first mapped into HSV. A delta is then applied to all the hue
// values, and the result is remapped back to the RGB colorspace.
//
// Arguments:
//	images: Images to adjust.  At least 3-D.
//	delta: A float delta to add to the hue.
//
// Returns the hue-adjusted image or images.
func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
	// A scope that already carries an error yields the zero-valued output.
	if err := scope.Err(); err != nil {
		return
	}
	// Input order is fixed: images first, then delta.
	inputs := []tf.Input{images, delta}
	op := scope.AddOperation(tf.OpSpec{
		Type:  "AdjustHue",
		Input: inputs,
	})
	output = op.Output(0)
	return
}
// Restore a Reader to its initial clean state.
//
// Arguments:
//	reader_handle: Handle to a Reader.
//
// Returns the created operation.
func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
	// A scope that already carries an error yields a nil operation.
	if err := scope.Err(); err != nil {
		return
	}
	o = scope.AddOperation(tf.OpSpec{
		Type:  "ReaderResetV2",
		Input: []tf.Input{reader_handle},
	})
	return
}

View File

@ -1 +0,0 @@
#include "unsupported/Eigen/CXX11/ThreadPool"

View File

@ -436,13 +436,13 @@ class EstimatorTrainTest(test.TestCase):
model_dir=model_dir1, model_dir=model_dir1,
model_fn=model_fn_global_step_incrementer) model_fn=model_fn_global_step_incrementer)
est1.train(dummy_input_fn, steps=5) est1.train(dummy_input_fn, steps=5)
# We have to clear the cache before we can rename the directory, # We have to clear the cache before we can rename the directory,
# otherwise open file handles will prevent the delete on Windows. # otherwise open file handles will prevent the delete on Windows.
writer_cache.FileWriterCache.clear() writer_cache.FileWriterCache.clear()
model_dir2 = os.path.join(tmpdir, 'model_dir2') model_dir2 = os.path.join(tmpdir, 'model_dir2')
os.renames(model_dir1, model_dir2) os.renames(model_dir1, model_dir2)
est2 = estimator.Estimator( est2 = estimator.Estimator(
model_dir=model_dir2, model_dir=model_dir2,
model_fn=model_fn_global_step_incrementer) model_fn=model_fn_global_step_incrementer)

View File

@ -129,6 +129,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -656,6 +657,44 @@ def categorical_column_with_vocabulary_list(
default_value=default_value) default_value=default_value)
def categorical_column_with_identity(key, num_buckets, default_value=None):
  """A `_CategoricalColumn` that returns identity values.

  Use this when your inputs are integers in the range `[0, num_buckets)`.
  Values outside this range are replaced by `default_value` if one is
  specified; otherwise the column's graph operations fail.

  Inputs can be either `Tensor` or `SparseTensor`.

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
    default_value: If `None`, this column's graph operations will fail for
      out-of-range inputs. Otherwise, this value must be in the range
      `[0, num_buckets)`, and will replace out-of-range inputs.

  Returns:
    A `_CategoricalColumn` that returns identity values.

  Raises:
    ValueError: if `num_buckets` is less than one.
    ValueError: if `default_value` is not in range `[0, num_buckets)`.
  """
  if num_buckets < 1:
    bucket_error = 'num_buckets {} < 1, column_name {}'.format(
        num_buckets, key)
    raise ValueError(bucket_error)

  if default_value is not None:
    # The default itself must be a valid bucket id.
    default_out_of_range = (default_value < 0) or (default_value >= num_buckets)
    if default_out_of_range:
      default_error = (
          'default_value {} not in range [0, {}), column_name {}'.format(
              default_value, num_buckets, key))
      raise ValueError(default_error)

  return _IdentityCategoricalColumn(
      key=key, num_buckets=num_buckets, default_value=default_value)
class _FeatureColumn(object): class _FeatureColumn(object):
"""Represents a feature column abstraction. """Represents a feature column abstraction.
@ -1384,6 +1423,69 @@ class _VocabularyListCategoricalColumn(
return _CategoricalColumn.IdWeightPair(inputs.get(self), None) return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
class _IdentityCategoricalColumn(
    _CategoricalColumn,
    collections.namedtuple(
        '_IdentityCategoricalColumn',
        ('key', 'num_buckets', 'default_value'))):
  """See `categorical_column_with_identity`."""

  @property
  def name(self):
    """The column name, which is the feature `key`."""
    return self.key

  @property
  def _parse_example_config(self):
    """Parsing spec: a variable-length int64 feature stored under `key`."""
    feature_spec = parsing_ops.VarLenFeature(dtypes.int64)
    return {self.key: feature_spec}

  def _transform_feature(self, inputs):
    """Builds the sparse int64 id tensor for this column.

    Args:
      inputs: Object whose `get(key)` returns the raw feature for `self.key`.

    Returns:
      A `SparseTensor` with the input's indices and dense shape. Its values
      are the inputs cast to int64, with out-of-range ids either replaced by
      `default_value` or guarded by runtime assertions.

    Raises:
      ValueError: if the input dtype is not an integer type.
    """
    sparse_input = _to_sparse_input(inputs.get(self.key))
    if not sparse_input.dtype.is_integer:
      raise ValueError(
          'Invalid input, not integer. key: {} dtype: {}'.format(
              self.key, sparse_input.dtype))

    # Op creation order and op names are kept stable on purpose: the
    # assertion op names surface in runtime error messages.
    ids = math_ops.to_int64(sparse_input.values, name='values')
    upper_bound = math_ops.to_int64(self.num_buckets, name='num_buckets')
    lower_bound = math_ops.to_int64(0, name='zero')

    if self.default_value is not None:
      # Assign default for out-of-range values.
      out_of_range = math_ops.logical_or(
          ids < lower_bound, ids >= upper_bound, name='out_of_range')
      defaults = array_ops.fill(
          dims=array_ops.shape(ids),
          value=math_ops.to_int64(self.default_value),
          name='default_values')
      ids = array_ops.where(out_of_range, defaults, ids)
    else:
      # Fail if values are out-of-range.
      below_upper = check_ops.assert_less(
          ids, upper_bound, data=(ids, upper_bound),
          name='assert_less_than_num_buckets')
      at_least_zero = check_ops.assert_greater_equal(
          ids, lower_bound, data=(ids,),
          name='assert_greater_or_equal_0')
      with ops.control_dependencies((below_upper, at_least_zero)):
        ids = array_ops.identity(ids)

    return sparse_tensor_lib.SparseTensor(
        indices=sparse_input.indices,
        values=ids,
        dense_shape=sparse_input.dense_shape)

  @property
  def _num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.num_buckets

  def _get_sparse_tensors(
      self, inputs, weight_collections=None, trainable=None):
    """Returns the ids for this column paired with no weights.

    `weight_collections` and `trainable` are accepted but not used here.
    """
    id_tensor = inputs.get(self)
    return _CategoricalColumn.IdWeightPair(id_tensor, None)
# TODO(zakaria): Move this to embedding_ops and make it public. # TODO(zakaria): Move this to embedding_ops and make it public.
def _safe_embedding_lookup_sparse(embedding_weights, def _safe_embedding_lookup_sparse(embedding_weights,
sparse_ids, sparse_ids,

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
@ -1828,5 +1829,198 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertAllClose(((3.,), (1.,)), predictions.eval()) self.assertAllClose(((3.,), (1.,)), predictions.eval())
class IdentityCategoricalColumnTest(test.TestCase):
  """Tests for `fc.categorical_column_with_identity`."""

  def test_constructor(self):
    # A freshly built column exposes its key as the name, its bucket count,
    # and an int64 variable-length parsing spec.
    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    self.assertEqual('aaa', column.name)
    # pylint: disable=protected-access
    self.assertEqual(3, column._num_buckets)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
    }, column._parse_example_config)
    # pylint: enable=protected-access

  def test_deep_copy(self):
    """Tests deepcopy of categorical_column_with_identity."""
    original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    # The original and its deep copy must expose identical properties.
    for column in (original, copy.deepcopy(original)):
      self.assertEqual('aaa', column.name)
      # pylint: disable=protected-access
      self.assertEqual(3, column._num_buckets)
      self.assertEqual({
          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
      }, column._parse_example_config)
      # pylint: enable=protected-access

  def test_invalid_num_buckets_zero(self):
    # num_buckets must be at least one.
    with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
      fc.categorical_column_with_identity(key='aaa', num_buckets=0)

  def test_invalid_num_buckets_negative(self):
    # Negative bucket counts are rejected at construction time.
    with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
      fc.categorical_column_with_identity(key='aaa', num_buckets=-1)

  def test_invalid_default_value_too_small(self):
    # default_value below zero is outside [0, num_buckets).
    with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
      fc.categorical_column_with_identity(
          key='aaa', num_buckets=3, default_value=-1)

  def test_invalid_default_value_too_big(self):
    # default_value equal to num_buckets is outside the half-open range.
    with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
      fc.categorical_column_with_identity(
          key='aaa', num_buckets=3, default_value=3)

  def test_invalid_input_dtype(self):
    # String-valued inputs must be rejected when the graph is built.
    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('omar', 'stringer', 'marlo'),
        dense_shape=(2, 2))
    with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
      # pylint: disable=protected-access
      column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
      # pylint: enable=protected-access

  def test_get_sparse_tensors(self):
    # In-range sparse ids pass through unchanged, as int64, with no weights.
    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(0, 1, 0),
        dense_shape=(2, 2))
    # pylint: disable=protected-access
    id_weight_pair = column._get_sparse_tensors(
        fc._LazyBuilder({'aaa': inputs}))
    # pylint: enable=protected-access
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((0, 1, 0), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_dense_input(self):
    # Dense input is converted to sparse; the expected result has no entry
    # at (0, 1), i.e. the -1 there is dropped during the conversion.
    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    # pylint: disable=protected-access
    id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
        'aaa': ((0, -1), (1, 0))
    }))
    # pylint: enable=protected-access
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=np.array((0, 1, 0), dtype=np.int64),
              dense_shape=(2, 2)),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_with_inputs_too_small(self):
    # Without a default_value, a negative id fails at evaluation time via
    # the lower-bound assertion op.
    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, -1, 0),
        dense_shape=(2, 2))
    # pylint: disable=protected-access
    id_weight_pair = column._get_sparse_tensors(
        fc._LazyBuilder({'aaa': inputs}))
    # pylint: enable=protected-access
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      with self.assertRaisesRegexp(
          errors.OpError, 'assert_greater_or_equal_0'):
        id_weight_pair.id_tensor.eval()

  def test_get_sparse_tensors_with_inputs_too_big(self):
    # Without a default_value, an id >= num_buckets fails at evaluation time
    # via the upper-bound assertion op.
    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, 99, 0),
        dense_shape=(2, 2))
    # pylint: disable=protected-access
    id_weight_pair = column._get_sparse_tensors(
        fc._LazyBuilder({'aaa': inputs}))
    # pylint: enable=protected-access
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      with self.assertRaisesRegexp(
          errors.OpError, 'assert_less_than_num_buckets'):
        id_weight_pair.id_tensor.eval()

  def test_get_sparse_tensors_with_default_value(self):
    # With default_value=3, both the negative and the too-large id are
    # replaced by 3 while the in-range id is kept.
    column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=4, default_value=3)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, -1, 99),
        dense_shape=(2, 2))
    # pylint: disable=protected-access
    id_weight_pair = column._get_sparse_tensors(
        fc._LazyBuilder({'aaa': inputs}))
    # pylint: enable=protected-access
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((1, 3, 3), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
    # Same as the default-value test, but indices, values and shape are
    # placeholders (values fed as int32) that are only supplied at eval time.
    column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=4, default_value=3)
    input_indices = array_ops.placeholder(dtype=dtypes.int64)
    input_values = array_ops.placeholder(dtype=dtypes.int32)
    input_shape = array_ops.placeholder(dtype=dtypes.int64)
    inputs = sparse_tensor.SparseTensorValue(
        indices=input_indices,
        values=input_values,
        dense_shape=input_shape)
    # pylint: disable=protected-access
    id_weight_pair = column._get_sparse_tensors(
        fc._LazyBuilder({'aaa': inputs}))
    # pylint: enable=protected-access
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
              values=np.array((1, 3, 3), dtype=np.int64),
              dense_shape=np.array((2, 2), dtype=np.int64)),
          id_weight_pair.id_tensor.eval(feed_dict={
              input_indices: ((0, 0), (1, 0), (1, 1)),
              input_values: (1, -1, 99),
              input_shape: (2, 2),
          }))

  def test_make_linear_model(self):
    # A linear model over the column starts with all-zero bias and weights;
    # after assigning weights, each prediction is the sum of the weights
    # selected by that row's ids.
    column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    self.assertEqual(3, column._num_buckets)
    with ops.Graph().as_default():
      predictions = fc.make_linear_model({
          column.name: sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=(0, 2, 1),
              dense_shape=(2, 2))
      }, (column,))
      bias = get_linear_model_bias()
      weight_var = get_linear_model_column_var(column)
      with _initialized_session():
        self.assertAllClose((0.,), bias.eval())
        self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
        self.assertAllClose(((0.,), (0.,)), predictions.eval())
        weight_var.assign(((1.,), (2.,), (3.,))).eval()
        # weight_var[0] = 1
        # weight_var[2] + weight_var[1] = 3+2 = 5
        self.assertAllClose(((1.,), (5.,)), predictions.eval())
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -113,18 +113,19 @@ def _add_op_node(op, func, input_dict):
node_def = func.node_def[-1] node_def = func.node_def[-1]
for i in range(len(node_def.input)): for i in range(len(node_def.input)):
if not node_def.input[i].startswith("^"): if not node_def.input[i].startswith("^"):
assert node_def.input[i] in input_dict, ( assert node_def.input[i] in input_dict, ("%s missing from %s" %
"%s missing from %s" % (node_def.input[i], input_dict.items())) (node_def.input[i],
input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]] node_def.input[i] = input_dict[node_def.input[i]]
def _graph_to_function_def(graph, inputs, outputs, out_names=None): def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
"""Returns `graph` as a `FunctionDef` protocol buffer. """Returns `graph` as a `FunctionDef` protocol buffer.
This method creates a [`FunctionDef`]( This method creates a [`FunctionDef`](
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto) https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
protocol buffer that contains all the ops present in the graph. The protocol buffer that contains all the ops in `operations`. The
graph effectively becomes the body of the function. operations become the body of the function.
The arguments `inputs` and `outputs` will be listed as the inputs The arguments `inputs` and `outputs` will be listed as the inputs
and outputs tensors of the function. They must be lists of and outputs tensors of the function. They must be lists of
@ -132,6 +133,8 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
Args: Args:
graph: Graph. graph: Graph.
operations: the operations to put in the function. Must be a subset of
the operations in the graph.
inputs: List of tensors. Inputs to the function. inputs: List of tensors. Inputs to the function.
outputs: List of tensors. Outputs of the function. outputs: List of tensors. Outputs of the function.
out_names: Optional list of string names for the outputs. out_names: Optional list of string names for the outputs.
@ -145,12 +148,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
func = function_pb2.FunctionDef() func = function_pb2.FunctionDef()
func.signature.name = "_" func.signature.name = "_"
used_names = set() used_names = set()
func.signature.input_arg.extend([_tensor_to_argdef(i, used_names=used_names) func.signature.input_arg.extend(
for i in inputs]) [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
if out_names is None: if out_names is None:
used_names = set() used_names = set()
func.signature.output_arg.extend([ func.signature.output_arg.extend(
_tensor_to_argdef(o, used_names=used_names) for o in outputs]) [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
elif len(outputs) != len(out_names): elif len(outputs) != len(out_names):
raise ValueError( raise ValueError(
"Length of out_names (%d) does not match number of outputs (%d): %s" % "Length of out_names (%d) does not match number of outputs (%d): %s" %
@ -159,12 +162,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
raise ValueError( raise ValueError(
"Must not have duplicates in out_names: %s" % ", ".join(out_names)) "Must not have duplicates in out_names: %s" % ", ".join(out_names))
else: else:
func.signature.output_arg.extend([ func.signature.output_arg.extend(
_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)]) [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs]) func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders) input_dict = _create_input_dict(graph, func_arg_placeholders)
for op in graph.get_operations(): for op in operations:
if _is_in_placeholders(op, func_arg_placeholders): if _is_in_placeholders(op, func_arg_placeholders):
continue continue
_add_op_node(op, func, input_dict) _add_op_node(op, func, input_dict)
@ -295,17 +298,18 @@ class _FuncGraph(ops.Graph):
self.extra_args = [] self.extra_args = []
self.extra_vars = [] self.extra_vars = []
def getvar(self, def getvar(
getter, self,
name, getter,
shape=None, name,
dtype=None, shape=None,
initializer=None, dtype=None,
reuse=None, initializer=None,
trainable=True, reuse=None,
collections=None, # pylint: disable=redefined-outer-name trainable=True,
use_resource=None, collections=None, # pylint: disable=redefined-outer-name
**kwargs): use_resource=None,
**kwargs):
"""A custom variable getter.""" """A custom variable getter."""
# Here, we switch the default graph to the outer graph and ask the # Here, we switch the default graph to the outer graph and ask the
# variable scope in which the function is defined to give us the # variable scope in which the function is defined to give us the
@ -538,20 +542,23 @@ class _DefinedFunction(object):
# Build the FunctionDef # Build the FunctionDef
self._definition = _graph_to_function_def( self._definition = _graph_to_function_def(
temp_graph, inputs, outputs, out_names=self._out_names) temp_graph,
temp_graph.get_operations(),
inputs,
outputs,
out_names=self._out_names)
# Extra kwargs are treated as attrs on the function def. # Extra kwargs are treated as attrs on the function def.
sig_pre_func_name = self._func_name or _get_func_name(self._func) sig_pre_func_name = self._func_name or _get_func_name(self._func)
kwargs_attr = _parse_kwargs_as_attrs( kwargs_attr = _parse_kwargs_as_attrs(sig_pre_func_name,
sig_pre_func_name, **self._extra_kwargs) **self._extra_kwargs)
for k in kwargs_attr: for k in kwargs_attr:
self._definition.attr[k].CopyFrom(kwargs_attr[k]) self._definition.attr[k].CopyFrom(kwargs_attr[k])
# Hash the definition and its dependencies. # Hash the definition and its dependencies.
self._hash_str = self._create_hash_str( self._hash_str = self._create_hash_str(
self._definition.signature.input_arg, self._definition.signature.input_arg,
self._definition.signature.output_arg, self._definition.signature.output_arg, self._definition.node_def)
self._definition.node_def)
# Finally, we decide the function name to use. If not specified, # Finally, we decide the function name to use. If not specified,
# make up something which is almost certainly unique (but deterministic). # make up something which is almost certainly unique (but deterministic).
@ -658,8 +665,8 @@ def _from_definition(fdef, grad_func=None):
# have access to such a callable here). # have access to such a callable here).
func = None func = None
argnames = [arg.name for arg in fdef.signature.input_arg] argnames = [arg.name for arg in fdef.signature.input_arg]
input_types = tuple(dtypes.as_dtype(arg.type) input_types = tuple(
for arg in fdef.signature.input_arg) dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
func_name = fdef.signature.name func_name = fdef.signature.name
# Note: FunctionDefs do not include python gradient functions, so if the # Note: FunctionDefs do not include python gradient functions, so if the
# original _DefinedFunction included one it will not be reflected here. # original _DefinedFunction included one it will not be reflected here.
@ -675,8 +682,7 @@ def _from_definition(fdef, grad_func=None):
result._extra_inputs = [] result._extra_inputs = []
result._hash_str = result._create_hash_str( result._hash_str = result._create_hash_str(
result._definition.signature.input_arg, result._definition.signature.input_arg,
result._definition.signature.output_arg, result._definition.signature.output_arg, result._definition.node_def)
result._definition.node_def)
# pylint: enable=protected-access # pylint: enable=protected-access
return result return result
@ -696,7 +702,8 @@ def _from_library(lib):
Raises: Raises:
ValueError: `lib` is invalid ValueError: `lib` is invalid
""" """
if not lib.function and not lib.gradient: return [] if not lib.function and not lib.gradient:
return []
# function name -> FunctionDef proto # function name -> FunctionDef proto
funcs = {fdef.signature.name: fdef for fdef in lib.function} funcs = {fdef.signature.name: fdef for fdef in lib.function}
@ -720,8 +727,9 @@ def _from_library(lib):
grad_to_funcs[gdef.gradient_func].append(gdef.function_name) grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
# Start with functions without gradients # Start with functions without gradients
ready = [fdef for fdef in lib.function ready = [
if func_to_grad[fdef.signature.name] is None] fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
]
if not ready: if not ready:
raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n" raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
+ str(lib)) + str(lib))
@ -733,7 +741,8 @@ def _from_library(lib):
name = fdef.signature.name name = fdef.signature.name
grad = initialized.get(func_to_grad[name]) grad = initialized.get(func_to_grad[name])
if func_to_grad[name]: assert grad if func_to_grad[name]:
assert grad
defined_func = _from_definition(fdef, grad_func=grad) defined_func = _from_definition(fdef, grad_func=grad)
initialized[name] = defined_func initialized[name] = defined_func
@ -835,10 +844,15 @@ class _OverloadedFunction(object):
name = self._func_name name = self._func_name
if name is not None: if name is not None:
name = "_".join([name, key]) name = "_".join([name, key])
defined = _DefinedFunction(self._func, self._argnames, input_types, name, defined = _DefinedFunction(
None, self._python_grad_func, self._func,
out_names=self._out_names, self._argnames,
**self._extra_kwargs) input_types,
name,
None,
self._python_grad_func,
out_names=self._out_names,
**self._extra_kwargs)
_ = defined.name # Fully instantiate the function definition. _ = defined.name # Fully instantiate the function definition.
if self._grad_func: if self._grad_func:
# If _grad_func is given, it is another # If _grad_func is given, it is another
@ -849,8 +863,8 @@ class _OverloadedFunction(object):
for _ in defined.definition.signature.output_arg for _ in defined.definition.signature.output_arg
] ]
# pylint: disable=protected-access # pylint: disable=protected-access
defined._grad_func = self._grad_func.instantiate(input_types + defined._grad_func = self._grad_func.instantiate(
output_types) input_types + output_types)
# pylint: enable=protected-access # pylint: enable=protected-access
self._overload[key] = defined self._overload[key] = defined
return defined return defined
@ -981,22 +995,36 @@ class Defun(object):
raise ValueError( raise ValueError(
"The function has fewer arguments than the number of specified " "The function has fewer arguments than the number of specified "
"input types.") "input types.")
return _DefinedFunction(func, argnames, self._input_types, return _DefinedFunction(
self._func_name, self._grad_func, func,
self._python_grad_func, argnames,
out_names=self._out_names, **self._extra_kwargs) self._input_types,
self._func_name,
self._grad_func,
self._python_grad_func,
out_names=self._out_names,
**self._extra_kwargs)
# 'func' expects no arguments and input types is an empty list. # 'func' expects no arguments and input types is an empty list.
if min_args == 0 and max_args == 0: if min_args == 0 and max_args == 0:
return _DefinedFunction(func, [], [], self._func_name, self._grad_func, return _DefinedFunction(
self._python_grad_func, func, [], [],
out_names=self._out_names, **self._extra_kwargs) self._func_name,
self._grad_func,
self._python_grad_func,
out_names=self._out_names,
**self._extra_kwargs)
# Input types are unknown. It's an overloaded function and hence # Input types are unknown. It's an overloaded function and hence
# its definition needs to be deferred until it's called. # its definition needs to be deferred until it's called.
return _OverloadedFunction(func, argnames, self._func_name, self._grad_func, return _OverloadedFunction(
self._python_grad_func, func,
out_names=self._out_names, **self._extra_kwargs) argnames,
self._func_name,
self._grad_func,
self._python_grad_func,
out_names=self._out_names,
**self._extra_kwargs)
class Declare(object): class Declare(object):
@ -1039,8 +1067,10 @@ class Declare(object):
names = [n for n, t in args] names = [n for n, t in args]
if len(names) != len(set(names)): if len(names) != len(set(names)):
raise ValueError("Expected names to all be unique: %s" % str(names)) raise ValueError("Expected names to all be unique: %s" % str(names))
return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n) return [
for n, t in args] op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
for n, t in args
]
self._sig.input_arg.extend(_to_argdef_list(inputs)) self._sig.input_arg.extend(_to_argdef_list(inputs))
self._sig.output_arg.extend(_to_argdef_list(outputs)) self._sig.output_arg.extend(_to_argdef_list(outputs))

View File

@ -1106,16 +1106,18 @@ class BinaryOpTest(test.TestCase):
def testAtan2SpecialValues(self): def testAtan2SpecialValues(self):
x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0), x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
(1.2345, float('inf')), (1.2345, -float('inf')), (1.2345, float("inf")), (1.2345, -float("inf")),
(-4.321, float('inf')), (-4.125, -float('inf')), (-4.321, float("inf")), (-4.125, -float("inf")),
(float('inf'), float('inf')), (float('inf'), -float('inf')), (float("inf"), float("inf")), (float("inf"), -float("inf")),
(-float('inf'), float('inf')), (-float('inf'), -float('inf'))) (-float("inf"), float("inf")), (-float("inf"),
-float("inf")))
for dtype in np.float32, np.float64: for dtype in np.float32, np.float64:
x1 = np.array(x1l).astype(dtype) x1 = np.array(x1l).astype(dtype)
x2 = np.array(x2l).astype(dtype) x2 = np.array(x2l).astype(dtype)
self._compareCpu(x1, x2, np.arctan2, math_ops.atan2) self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
self._compareGpu(x1, x2, np.arctan2, math_ops.atan2) self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
class ComparisonOpTest(test.TestCase): class ComparisonOpTest(test.TestCase):
def _compareScalar(self, func, x, y, dtype): def _compareScalar(self, func, x, y, dtype):

View File

@ -19,58 +19,65 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow import Tensor
from tensorflow import register_tensor_conversion_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test as test_lib from tensorflow.python.platform import test as test_lib
class TensorPriorityTest(test_lib.TestCase): class TensorPriorityTest(test_lib.TestCase):
def testSupportedRhsWithoutDelegation(self): def testSupportedRhsWithoutDelegation(self):
class NumpyArraySubclass(np.ndarray): class NumpyArraySubclass(np.ndarray):
pass pass
supported_rhs_without_delegation = (
3, supported_rhs_without_delegation = (3, 3.0, [1.0, 2.0], np.array(
3.0, [1.0, 2.0]), NumpyArraySubclass(
[1.0, 2.0], shape=(1, 2), buffer=np.array([1.0, 2.0])),
np.array([1.0, 2.0]), ops.convert_to_tensor([[1.0, 2.0]]))
NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0])),
ops.convert_to_tensor([[1.0, 2.0]]))
for rhs in supported_rhs_without_delegation: for rhs in supported_rhs_without_delegation:
tensor = ops.convert_to_tensor([[10.0, 20.0]]) tensor = ops.convert_to_tensor([[10.0, 20.0]])
res = tensor + rhs res = tensor + rhs
self.assertIsInstance(res, Tensor) self.assertIsInstance(res, ops.Tensor)
def testUnsupportedRhsWithoutDelegation(self): def testUnsupportedRhsWithoutDelegation(self):
class WithoutReverseAdd(object): class WithoutReverseAdd(object):
pass pass
tensor = ops.convert_to_tensor([[10.0, 20.0]]) tensor = ops.convert_to_tensor([[10.0, 20.0]])
rhs = WithoutReverseAdd() rhs = WithoutReverseAdd()
with self.assertRaisesWithPredicateMatch( with self.assertRaisesWithPredicateMatch(
TypeError, lambda e: "Expected float" in str(e)): TypeError, lambda e: "Expected float" in str(e)):
res = tensor + rhs # pylint: disable=pointless-statement
tensor + rhs
def testUnsupportedRhsWithDelegation(self): def testUnsupportedRhsWithDelegation(self):
class WithReverseAdd(object): class WithReverseAdd(object):
def __radd__(self, lhs): def __radd__(self, lhs):
return "Works!" return "Works!"
tensor = ops.convert_to_tensor([[10.0, 20.0]]) tensor = ops.convert_to_tensor([[10.0, 20.0]])
rhs = WithReverseAdd() rhs = WithReverseAdd()
res = tensor + rhs res = tensor + rhs
self.assertEqual(res, "Works!") self.assertEqual(res, "Works!")
def testFullDelegationControlUsingRegistry(self): def testFullDelegationControlUsingRegistry(self):
class NumpyArraySubclass(np.ndarray): class NumpyArraySubclass(np.ndarray):
def __radd__(self, lhs): def __radd__(self, lhs):
return "Works!" return "Works!"
def raise_to_delegate(value, dtype=None, name=None, as_ref=False): def raise_to_delegate(value, dtype=None, name=None, as_ref=False):
del value, dtype, name, as_ref # Unused.
raise TypeError raise TypeError
register_tensor_conversion_function(NumpyArraySubclass, raise_to_delegate,
priority=0) ops.register_tensor_conversion_function(
NumpyArraySubclass, raise_to_delegate, priority=0)
tensor = ops.convert_to_tensor([[10.0, 20.0]]) tensor = ops.convert_to_tensor([[10.0, 20.0]])
rhs = NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0])) rhs = NumpyArraySubclass(shape=(1, 2), buffer=np.array([1.0, 2.0]))
res = tensor + rhs res = tensor + rhs
self.assertEqual(res, "Works!") self.assertEqual(res, "Works!")

View File

@ -1109,10 +1109,10 @@ class Conv2DTranspose(Conv2D):
# Infer the static output shape: # Infer the static output shape:
out_shape = inputs.get_shape().as_list() out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters out_shape[c_axis] = self.filters
out_shape[h_axis] = utils.get_deconv_dim( out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
out_shape[h_axis], stride_h, kernel_h, self.padding) kernel_h, self.padding)
out_shape[w_axis] = utils.get_deconv_dim( out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
out_shape[w_axis], stride_w, kernel_w, self.padding) kernel_w, self.padding)
outputs.set_shape(out_shape) outputs.set_shape(out_shape)
if self.bias: if self.bias:
@ -1240,7 +1240,8 @@ class Conv3DTranspose(Conv3D):
name: A string, the name of the layer. name: A string, the name of the layer.
""" """
def __init__(self, filters, def __init__(self,
filters,
kernel_size, kernel_size,
strides=(1, 1, 1), strides=(1, 1, 1),
padding='valid', padding='valid',
@ -1269,12 +1270,13 @@ class Conv3DTranspose(Conv3D):
bias_regularizer=bias_regularizer, bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer, activity_regularizer=activity_regularizer,
trainable=trainable, trainable=trainable,
name=name, **kwargs) name=name,
**kwargs)
def build(self, input_shape): def build(self, input_shape):
if len(input_shape) != 5: if len(input_shape) != 5:
raise ValueError('Inputs should have rank 5, ' + raise ValueError('Inputs should have rank 5, received input shape:',
'received input shape:', str(input_shape)) str(input_shape))
if self.data_format == 'channels_first': if self.data_format == 'channels_first':
channel_axis = 1 channel_axis = 1
else: else:
@ -1285,22 +1287,23 @@ class Conv3DTranspose(Conv3D):
input_dim = input_shape[channel_axis] input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (self.filters, input_dim) kernel_shape = self.kernel_size + (self.filters, input_dim)
self.kernel = self.add_variable('kernel', self.kernel = self.add_variable(
shape=kernel_shape, 'kernel',
initializer=self.kernel_initializer, shape=kernel_shape,
regularizer=self.kernel_regularizer, initializer=self.kernel_initializer,
trainable=True, regularizer=self.kernel_regularizer,
dtype=self.dtype) trainable=True,
dtype=self.dtype)
if self.use_bias: if self.use_bias:
self.bias = self.add_variable('bias', self.bias = self.add_variable(
shape=(self.filters,), 'bias',
initializer=self.bias_initializer, shape=(self.filters,),
regularizer=self.bias_regularizer, initializer=self.bias_initializer,
trainable=True, regularizer=self.bias_regularizer,
dtype=self.dtype) trainable=True,
dtype=self.dtype)
else: else:
self.bias = None self.bias = None
self.built = True
def call(self, inputs): def call(self, inputs):
inputs_shape = array_ops.shape(inputs) inputs_shape = array_ops.shape(inputs)
@ -1343,26 +1346,26 @@ class Conv3DTranspose(Conv3D):
# Infer the static output shape: # Infer the static output shape:
out_shape = inputs.get_shape().as_list() out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters out_shape[c_axis] = self.filters
out_shape[d_axis] = utils.get_deconv_dim( out_shape[d_axis] = utils.get_deconv_dim(out_shape[d_axis], stride_d,
out_shape[d_axis], stride_d, kernel_d, self.padding) kernel_d, self.padding)
out_shape[h_axis] = utils.get_deconv_dim( out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
out_shape[h_axis], stride_h, kernel_h, self.padding) kernel_h, self.padding)
out_shape[w_axis] = utils.get_deconv_dim( out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
out_shape[w_axis], stride_w, kernel_w, self.padding) kernel_w, self.padding)
outputs.set_shape(out_shape) outputs.set_shape(out_shape)
if self.bias: if self.bias:
outputs_shape = outputs.shape.as_list() outputs_shape = outputs.shape.as_list()
if self.data_format == 'channels_first': if self.data_format == 'channels_first':
outputs_4d = array_ops.reshape(outputs, outputs_4d = array_ops.reshape(outputs, [
[outputs_shape[0], outputs_shape[1], outputs_shape[0], outputs_shape[1],
outputs_shape[2] * outputs_shape[3], outputs_shape[2] * outputs_shape[3], outputs_shape[4]
outputs_shape[4]]) ])
else: else:
outputs_4d = array_ops.reshape(outputs, outputs_4d = array_ops.reshape(outputs, [
[outputs_shape[0], outputs_shape[0], outputs_shape[1] * outputs_shape[2],
outputs_shape[1] * outputs_shape[2], outputs_shape[3], outputs_shape[4]
outputs_shape[3], outputs_shape[4]]) ])
outputs_4d = nn.bias_add( outputs_4d = nn.bias_add(
outputs_4d, outputs_4d,
self.bias, self.bias,

View File

@ -715,8 +715,8 @@ class Conv3DTransposeTest(test.TestCase):
layer = conv_layers.Conv3DTranspose( layer = conv_layers.Conv3DTranspose(
32, volumes.get_shape()[1:4], padding='same') 32, volumes.get_shape()[1:4], padding='same')
output = layer.apply(volumes) output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(), [5, depth, height, self.assertListEqual(output.get_shape().as_list(),
width, 32]) [5, depth, height, width, 32])
def testCreateConv3DTransposeWithStrides(self): def testCreateConv3DTransposeWithStrides(self):
depth, height, width = 4, 6, 8 depth, height, width = 4, 6, 8
@ -729,8 +729,7 @@ class Conv3DTransposeTest(test.TestCase):
[5, depth * 2, height * 2, width * 2, 4]) [5, depth * 2, height * 2, width * 2, 4])
# Test strides integer. # Test strides integer.
layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2, layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2, padding='same')
padding='same')
output = layer.apply(volumes) output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(), self.assertListEqual(output.get_shape().as_list(),
[5, depth * 2, height * 2, width * 2, 4]) [5, depth * 2, height * 2, width * 2, 4])
@ -779,14 +778,14 @@ class Conv3DTransposeTest(test.TestCase):
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1) volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1') conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
self.assertEqual(len(variables.trainable_variables()), 2) self.assertEqual(len(variables.trainable_variables()), 2)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1', reuse=True) conv_layers.conv3d_transpose(
volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
self.assertEqual(len(variables.trainable_variables()), 2) self.assertEqual(len(variables.trainable_variables()), 2)
def testFunctionalConv3DTransposeReuseFromScope(self): def testFunctionalConv3DTransposeReuseFromScope(self):
with variable_scope.variable_scope('scope'): with variable_scope.variable_scope('scope'):
depth, height, width = 5, 7, 9 depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32), volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1') conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
self.assertEqual(len(variables.trainable_variables()), 2) self.assertEqual(len(variables.trainable_variables()), 2)
with variable_scope.variable_scope('scope', reuse=True): with variable_scope.variable_scope('scope', reuse=True):
@ -798,8 +797,8 @@ class Conv3DTransposeTest(test.TestCase):
with variable_scope.variable_scope( with variable_scope.variable_scope(
'scope', initializer=init_ops.ones_initializer()): 'scope', initializer=init_ops.ones_initializer()):
depth, height, width = 5, 7, 9 depth, height, width = 5, 7, 9
volumes = random_ops.random_uniform((5, depth, height, width, 32), volumes = random_ops.random_uniform(
seed=1) (5, depth, height, width, 32), seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1') conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
weights = variables.trainable_variables() weights = variables.trainable_variables()
# Check the names of weights in order. # Check the names of weights in order.

View File

@ -205,7 +205,8 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
`decoded.shape`: Shape vector, size `(2)`. `decoded.shape`: Shape vector, size `(2)`.
The shape values are: `[batch_size, max_decoded_length]` The shape values are: `[batch_size, max_decoded_length]`
neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
sequence found, the negative of the sum of the greatest logit at each timeframe. sequence found, the negative of the sum of the greatest logit at each
timeframe.
""" """
outputs = gen_ctc_ops._ctc_greedy_decoder( outputs = gen_ctc_ops._ctc_greedy_decoder(
inputs, sequence_length, merge_repeated=merge_repeated) inputs, sequence_length, merge_repeated=merge_repeated)

View File

@ -39,6 +39,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops

View File

@ -964,8 +964,12 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
ValueError: If input/output depth does not match `filters`' shape, or if ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`. padding is other than `'VALID'` or `'SAME'`.
""" """
return convolution(input=value, filter=filters, padding=padding, return convolution(
dilation_rate=np.broadcast_to(rate, (2, )), name=name) input=value,
filter=filters,
padding=padding,
dilation_rate=np.broadcast_to(rate, (2,)),
name=name)
def conv2d_transpose(value, def conv2d_transpose(value,
@ -1231,8 +1235,8 @@ def conv3d_transpose(value,
axis = 1 if data_format == "NCDHW" else 4 axis = 1 if data_format == "NCDHW" else 4
if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]): if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]):
raise ValueError("input channels does not match filter's input channels, " raise ValueError("input channels does not match filter's input channels, "
"{} != {}".format(value.get_shape()[axis], filter.get_shape( "{} != {}".format(value.get_shape()[axis],
)[4])) filter.get_shape()[4]))
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)): if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)):

View File

@ -195,46 +195,47 @@ def load(sess, tags, export_dir, **saver_kwargs):
Raises: Raises:
RuntimeError: MetaGraphDef associated with the tags cannot be found. RuntimeError: MetaGraphDef associated with the tags cannot be found.
""" """
# Build the SavedModel protocol buffer and find the requested meta graph def. with sess.graph.as_default():
saved_model = _parse_saved_model(export_dir) # Build the SavedModel protocol buffer and find requested meta graph def.
found_match = False saved_model = _parse_saved_model(export_dir)
for meta_graph_def in saved_model.meta_graphs: found_match = False
if set(meta_graph_def.meta_info_def.tags) == set(tags): for meta_graph_def in saved_model.meta_graphs:
meta_graph_def_to_load = meta_graph_def if set(meta_graph_def.meta_info_def.tags) == set(tags):
found_match = True meta_graph_def_to_load = meta_graph_def
break found_match = True
break
if not found_match: if not found_match:
raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip( raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
"[]") + " could not be found in SavedModel") "[]") + " could not be found in SavedModel")
# Build a saver by importing the meta graph def to load. # Build a saver by importing the meta graph def to load.
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs) saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
if saver: if saver:
# Build the checkpoint path where the variables are located. # Build the checkpoint path where the variables are located.
variables_path = os.path.join( variables_path = os.path.join(
compat.as_bytes(export_dir), compat.as_bytes(export_dir),
compat.as_bytes(constants.VARIABLES_DIRECTORY), compat.as_bytes(constants.VARIABLES_DIRECTORY),
compat.as_bytes(constants.VARIABLES_FILENAME)) compat.as_bytes(constants.VARIABLES_FILENAME))
# Restore the variables using the built saver in the provided session. # Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path) saver.restore(sess, variables_path)
else: else:
tf_logging.info("The specified SavedModel has no variables; no " tf_logging.info("The specified SavedModel has no variables; no "
"checkpoints were restored.") "checkpoints were restored.")
# Get asset tensors, if any. # Get asset tensors, if any.
asset_tensors_dictionary = _get_asset_tensors(export_dir, asset_tensors_dictionary = _get_asset_tensors(export_dir,
meta_graph_def_to_load) meta_graph_def_to_load)
main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load) main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
if main_op_tensor is not None: if main_op_tensor is not None:
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
else: else:
legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load) legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
if legacy_init_op_tensor is not None: if legacy_init_op_tensor is not None:
sess.run(fetches=[legacy_init_op_tensor], sess.run(
feed_dict=asset_tensors_dictionary) fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary)
return meta_graph_def_to_load return meta_graph_def_to_load

View File

@ -151,6 +151,27 @@ class SavedModelTest(test.TestCase):
constants.SAVED_MODEL_FILENAME_PBTXT): constants.SAVED_MODEL_FILENAME_PBTXT):
loader.load(sess, ["foo"], export_dir) loader.load(sess, ["foo"], export_dir)
# Regression test for loading into a session whose graph is not the default
# graph: a SavedModel is written from one fresh graph, then loader.load is
# called with a session built on another fresh ops.Graph(), outside any
# `with sess:` scope. The restored variable is then read back inside the
# session's own scope and is expected to still hold 42.
def testVerifySessionGraphUsage(self):
export_dir = os.path.join(test.get_temp_dir(),
"test_verify_session_graph_usage")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.test_session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
# Save the SavedModel to disk.
builder.save()
# Build a session and supply it to the load operation.
sess = session.Session(graph=ops.Graph())
loader.load(sess, [tag_constants.TRAINING], export_dir)
# Check the variable within the scope of the session and its graph.
with sess:
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
def testSequence(self): def testSequence(self):
export_dir = os.path.join(test.get_temp_dir(), "test_sequence") export_dir = os.path.join(test.get_temp_dir(), "test_sequence")
builder = saved_model_builder.SavedModelBuilder(export_dir) builder = saved_model_builder.SavedModelBuilder(export_dir)

View File

@ -12,33 +12,39 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ================================ # ================================
"""Imports a protobuf model as a graph in Tensorboard."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary
def import_to_tensorboard(model_dir, log_dir): def import_to_tensorboard(model_dir, log_dir):
"""View an imported protobuf model (`.pb` file) as a graph in Tensorboard. """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.
Args: Args:
model_dir: The location of the protobuf (`pb`) model to visualize model_dir: The location of the protobuf (`pb`) model to visualize
log_dir: The location for the Tensorboard log to begin visualisation from. log_dir: The location for the Tensorboard log to begin visualisation from.
Usage: Usage:
Call this function with your model location and desired log directory. Call this function with your model location and desired log directory.
Launch Tensorboard by pointing it to the log directory. Launch Tensorboard by pointing it to the log directory.
View your imported `.pb` model as a graph. View your imported `.pb` model as a graph.
""" """
with tf.Session(graph=tf.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
with tf.gfile.FastGFile(model_dir, 'rb') as f: with gfile.FastGFile(model_dir, "rb") as f:
graph_def = tf.GraphDef() graph_def = graph_pb2.GraphDef()
graph_def.ParseFromString(f.read()) graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def) importer.import_graph_def(graph_def)
pb_visual_writer = tf.summary.FileWriter(log_dir) pb_visual_writer = summary.FileWriter(log_dir)
pb_visual_writer.add_graph(sess.graph) pb_visual_writer.add_graph(sess.graph)
print("Model Imported. Visualize by running: " print("Model Imported. Visualize by running: "
"> tensorboard --logdir={}".format(log_dir)) "> tensorboard --logdir={}".format(log_dir))

View File

@ -504,7 +504,14 @@ def run(args):
Args: Args:
args: A namespace parsed from command line. args: A namespace parsed from command line.
Raises:
AttributeError: An error when neither --inputs nor --input_exprs is passed
to run command.
""" """
if not args.inputs and not args.input_exprs:
raise AttributeError(
'At least one of --inputs and --input_exprs must be required')
tensor_key_feed_dict = load_inputs_from_input_arg_string( tensor_key_feed_dict = load_inputs_from_input_arg_string(
args.inputs, args.input_exprs) args.inputs, args.input_exprs)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def, run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
@ -629,8 +636,6 @@ def create_parser():
def main(): def main():
parser = create_parser() parser = create_parser()
args = parser.parse_args() args = parser.parse_args()
if not args.inputs and not args.input_exprs:
args.error('At least one of --inputs and --input_exprs is required')
args.func(args) args.func(args)

View File

@ -409,6 +409,16 @@ Method name is: tensorflow/serving/predict"""
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
saved_model_cli.run(args) saved_model_cli.run(args)
# Verifies that saved_model_cli.run raises AttributeError when the `run`
# command line is parsed with neither --inputs nor --input_exprs.
def testRunCommandInputNotGivenError(self):
self.parser = saved_model_cli.create_parser()
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
# Only --dir, --tag_set and --signature_def are supplied; both input flags
# are deliberately left out.
args = self.parser.parse_args([
'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
'serving_default'
])
with self.assertRaises(AttributeError):
saved_model_cli.run(args)
def testRunCommandWithDebuggerEnabled(self): def testRunCommandWithDebuggerEnabled(self):
self.parser = saved_model_cli.create_parser() self.parser = saved_model_cli.create_parser()
base_path = test.test_src_dir_path(SAVED_MODEL_PATH) base_path = test.test_src_dir_path(SAVED_MODEL_PATH)

View File

@ -210,9 +210,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
else: else:
var_name = ",".join([v.name for v in var]) var_name = ",".join([v.name for v in var])
_set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt) _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
logging.info("Initialize variable %s from checkpoint %s with %s" % ( logging.info("Initialize variable %s from checkpoint %s with %s",
var_name, ckpt_dir_or_file, tensor_name_in_ckpt var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
))
else: else:
scopes = "" scopes = ""
# TODO(vihanjain): Support list of 'current_var_or_name' here. # TODO(vihanjain): Support list of 'current_var_or_name' here.
@ -250,9 +249,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
if var is None: if var is None:
var = _collect_partitioned_variable(var_name, store_vars) var = _collect_partitioned_variable(var_name, store_vars)
_set_variable_or_list_initializer(var, ckpt_file, full_tensor_name) _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
logging.info("Initialize variable %s from checkpoint %s with %s" % ( logging.info("Initialize variable %s from checkpoint %s with %s",
var_name, ckpt_dir_or_file, full_tensor_name var_name, ckpt_dir_or_file, full_tensor_name)
))
def _get_checkpoint_filename(ckpt_dir_or_file): def _get_checkpoint_filename(ckpt_dir_or_file):

View File

@ -935,11 +935,11 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p) ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
except errors.OpError as e: except errors.OpError as e:
# It's ok if the file cannot be read # It's ok if the file cannot be read
logging.warning("%s: %s" % (type(e).__name__, e)) logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None return None
except text_format.ParseError as e: except text_format.ParseError as e:
logging.warning("%s: %s" % (type(e).__name__, e)) logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename) logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None return None
finally: finally:

View File

@ -230,13 +230,15 @@ class TensorboardServerTest(test.TestCase):
def testScalars(self): def testScalars(self):
"""Test the format of /data/scalars.""" """Test the format of /data/scalars."""
data = self._getJson('/data/scalars?run=run1&tag=simple_values') data = self._getJson('/data/scalars?run=run1&tag=simple_values')
self.assertEqual(len(data),self._SCALAR_COUNT) self.assertEqual(len(data), self._SCALAR_COUNT)
def testScalarsCsv(self): def testScalarsCsv(self):
"""Test the csv format of /data/scalars.""" """Test the csv format of /data/scalars."""
data = self._get('/data/scalars?run=run1&tag=simple_values&format=csv').read() data = self._get(
'/data/scalars?run=run1&tag=simple_values&format=csv').read()
line_count = data.count('\n') line_count = data.count('\n')
self.assertEqual(line_count,self._SCALAR_COUNT + 1) # include 1 more line for header self.assertEqual(line_count,
self._SCALAR_COUNT + 1) # include 1 more line for header
def testHistograms(self): def testHistograms(self):
"""Test the format of /data/histograms.""" """Test the format of /data/histograms."""

View File

@ -0,0 +1,63 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
licenses(["notice"]) # Apache 2.0
# Audio dashboard component: the three tf-audio-* HTML sources, declared with
# the "/tf-audio-dashboard" path. Depends on the tf_backend and
# tf_dashboard_common components, lodash, and several Polymer paper elements.
# NOTE(review): `path` is presumably the location the files are served from;
# confirm against the rules_closure `webfiles` documentation.
webfiles(
name = "tf_audio_dashboard",
srcs = [
"tf-audio-dashboard.html",
"tf-audio-grid.html",
"tf-audio-loader.html",
],
path = "/tf-audio-dashboard",
deps = [
"//tensorflow/tensorboard/components/tf_backend",
"//tensorflow/tensorboard/components/tf_dashboard_common",
"//tensorflow/tensorboard/components/tf_imports:lodash",
"@org_polymer",
"@org_polymer_paper_icon_button",
"@org_polymer_paper_slider",
"@org_polymer_paper_spinner",
"@org_polymer_paper_styles",
],
)
# Recursive glob of every file in this package. The top-level //tensorflow
# BUILD file lists this target as
# //tensorflow/tensorboard/components/tf_audio_dashboard:all_files.
filegroup(
name = "all_files",
srcs = glob(["**"]),
tags = ["notsan"],
)
################################################################################
# MARKED FOR DELETION
# Legacy packaging of the same three tf-audio-* HTML sources, built with
# tensorboard_webcomponent_library into the "tf-audio-dashboard" destdir.
# Lives in the section marked for deletion.
# NOTE(review): unlike :tf_audio_dashboard, this target lists no paper-slider
# or paper-spinner dependency; confirm whether that difference is intended.
tensorboard_webcomponent_library(
name = "legacy",
srcs = [
"tf-audio-dashboard.html",
"tf-audio-grid.html",
"tf-audio-loader.html",
],
destdir = "tf-audio-dashboard",
deps = [
"//tensorflow/tensorboard/components:tf_imports",
"//tensorflow/tensorboard/components/tf_backend:legacy",
"//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
"//third_party/javascript/polymer/v1/paper-icon-button:lib",
"//third_party/javascript/polymer/v1/paper-styles:lib",
"//third_party/javascript/polymer/v1/polymer:lib",
],
)
# This is needed: components/BUILD seeks a legacy_ts rule in this package.
# The target declares no TypeScript sources of its own (srcs is empty) and
# depends only on the shared common_deps target.
tensorboard_ts_library(
name = "legacy_ts",
srcs = [],
deps_mgmt = "off",
runtime = "nodejs",
deps = ["//tensorflow/tensorboard/components:common_deps"],
)

View File

@ -0,0 +1,26 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
licenses(["notice"]) # Apache 2.0
# bazel run //tensorflow/tensorboard/components/tf_audio_dashboard/demo
# Demo page (index.html) for the audio dashboard, declared with the
# "/tf-audio-dashboard/demo" path. Depends on the dashboard component itself,
# the demo data package, d3, and the Polymer demo helpers.
webfiles(
name = "demo",
srcs = ["index.html"],
path = "/tf-audio-dashboard/demo",
deps = [
"//tensorflow/tensorboard/components/tf_audio_dashboard",
"//tensorflow/tensorboard/components/tf_audio_dashboard/demo/data",
"//tensorflow/tensorboard/components/tf_imports:d3",
"@org_polymer_iron_demo_helpers",
"@org_polymer_paper_styles",
"@org_polymer_webcomponentsjs",
],
)
# Recursive glob of every file in this package. The top-level //tensorflow
# BUILD file lists this target as
# //tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files.
filegroup(
name = "all_files",
srcs = glob(["**"]),
tags = ["notsan"],
)

View File

@ -0,0 +1,17 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
licenses(["notice"]) # Apache 2.0
# Demo data for the audio dashboard: every file directly in this directory
# (non-recursive `*` glob), declared with the "/tf-audio-dashboard/demo/data"
# path.
# NOTE(review): a bare `*` pattern may also match this BUILD file itself;
# confirm that is acceptable for `webfiles` srcs.
webfiles(
name = "data",
srcs = glob(["*"]),
path = "/tf-audio-dashboard/demo/data",
)
# Recursive glob of every file in this package, exposed as :all_files in the
# same way as the sibling tensorboard component packages.
filegroup(
name = "all_files",
srcs = glob(["**"]),
tags = ["notsan"],
)

View File

@ -107,6 +107,8 @@ future for loading older clips.
</template> </template>
</template> </template>
<script> <script>
"use strict";
Polymer({ Polymer({
is: "tf-audio-loader", is: "tf-audio-loader",
properties: { properties: {

View File

@ -0,0 +1,81 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
licenses(["notice"]) # Apache 2.0
# TODO(dandelion): Add webfiles support for the test code.
# Backend component: tf-backend.html plus the output of the :ts target
# declared in this file, under the "/tf-backend" path. Depends on d3, lodash,
# the vz_sorting component, and Polymer.
webfiles(
name = "tf_backend",
srcs = [
"tf-backend.html",
":ts",
],
path = "/tf-backend",
deps = [
"//tensorflow/tensorboard/components/tf_imports:d3",
"//tensorflow/tensorboard/components/tf_imports:lodash",
"//tensorflow/tensorboard/components/vz_sorting",
"@org_polymer",
],
)
# TypeScript step for the five backend sources, using the d3 and lodash
# typings from @org_definitelytyped plus the vz_sorting typings. Its output is
# consumed by :tf_backend through the ":ts" entry in that target's srcs.
tensorboard_typescript_genrule(
name = "ts",
srcs = [
"backend.ts",
"behavior.ts",
"requestManager.ts",
"router.ts",
"urlPathHelpers.ts",
],
typings = [
"@org_definitelytyped//:d3.d.ts",
"@org_definitelytyped//:lodash.d.ts",
"//tensorflow/tensorboard/components/vz_sorting:ts_typings",
],
)
# Recursive glob of every file in this package. The top-level //tensorflow
# BUILD file lists this target as
# //tensorflow/tensorboard/components/tf_backend:all_files.
filegroup(
name = "all_files",
srcs = glob(["**"]),
tags = ["notsan"],
)
################################################################################
# MARKED FOR DELETION
# Legacy packaging of tf-backend.html together with the :legacy_ts TypeScript
# library, built into the "tf-backend" destdir. Lives in the section marked
# for deletion. Visibility is explicitly public here, overriding the
# package-level default of //tensorflow:internal.
tensorboard_webcomponent_library(
name = "legacy",
srcs = [
"tf-backend.html",
":legacy_ts",
],
visibility = ["//visibility:public"],
destdir = "tf-backend",
deps = [
"//tensorflow/tensorboard/components:tf_imports",
"//third_party/javascript/polymer/v1/polymer:lib",
],
)
# Legacy TypeScript library over the same five backend sources that :ts
# compiles. Depends on the shared common_deps target and on the legacy
# vz_sorting TypeScript library.
tensorboard_ts_library(
name = "legacy_ts",
srcs = [
"backend.ts",
"behavior.ts",
"requestManager.ts",
"router.ts",
"urlPathHelpers.ts",
],
deps_mgmt = "off",
runtime = "nodejs",
deps = [
"//tensorflow/tensorboard/components:common_deps",
"//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
],
)

View File

@ -0,0 +1,45 @@
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
licenses(["notice"]) # Apache 2.0
# TypeScript library for the backend sources in this package. It depends on
# the d3 v4 variant of vz_sorting and the d3_v4 typings bundle, along with
# typings for lodash, Polymer, and the chai/mocha/sinon test libraries.
tensorboard_ts_library(
name = "ts",
srcs = [
"backend.ts",
"behavior.ts",
"requestManager.ts",
"router.ts",
"urlPathHelpers.ts",
],
deps = [
"//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
"//third_party/javascript/node_modules/typescript:es2015.promise",
"//third_party/javascript/typings/chai",
"//third_party/javascript/typings/d3_v4:bundle",
"//third_party/javascript/typings/lodash",
"//third_party/javascript/typings/mocha",
"//third_party/javascript/typings/polymer:polymer_without_externs",
"//third_party/javascript/typings/sinon",
],
)
# TODO(dandelion): Add runners for these tests
# TypeScript library holding the three *Tests.ts sources; it depends only on
# the :ts library above. No test runner target is declared for it in this
# file.
tensorboard_ts_library(
name = "tests",
srcs = [
"backendTests.ts",
"behaviorTests.ts",
"requestManagerTests.ts",
],
deps = [
":ts",
],
)
# Recursive glob of every file in this package. The top-level //tensorflow
# BUILD file lists this target as
# //tensorflow/tensorboard/components/tf_backend_d3v4:all_files.
filegroup(
name = "all_files",
srcs = glob(["**"]),
tags = ["notsan"],
)

View File

@ -0,0 +1,65 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
licenses(["notice"]) # Apache 2.0
# TODO(dandelion): Add webfiles support for the test code.
# Color scale component: tf-color-scale.html plus the output of the :ts
# target declared in this file, under the "/tf-color-scale" path. Depends on
# d3 and Polymer.
webfiles(
name = "tf_color_scale",
srcs = [
"tf-color-scale.html",
":ts",
],
path = "/tf-color-scale",
deps = [
"//tensorflow/tensorboard/components/tf_imports:d3",
"@org_polymer",
],
)
# TypeScript step for colorScale.ts and palettes.ts, using the d3 typings
# from @org_definitelytyped. Its output is consumed by :tf_color_scale
# through the ":ts" entry in that target's srcs.
tensorboard_typescript_genrule(
name = "ts",
srcs = [
"colorScale.ts",
"palettes.ts",
],
typings = ["@org_definitelytyped//:d3.d.ts"],
)
# Recursive glob of every file in this package. The top-level //tensorflow
# BUILD file lists this target as
# //tensorflow/tensorboard/components/tf_color_scale:all_files.
filegroup(
name = "all_files",
srcs = glob(["**"]),
tags = ["notsan"],
)
################################################################################
# MARKED FOR DELETION
# Legacy packaging of tf-color-scale.html together with the :legacy_ts
# TypeScript library, built into the "tf-color-scale" destdir. Lives in the
# section marked for deletion.
tensorboard_webcomponent_library(
name = "legacy",
srcs = [
"tf-color-scale.html",
":legacy_ts",
],
destdir = "tf-color-scale",
deps = [
"//tensorflow/tensorboard/components:tf_imports",
"//third_party/javascript/polymer/v1/polymer:lib",
],
)
# Legacy TypeScript library over the same two sources that :ts compiles
# (colorScale.ts and palettes.ts); depends only on the shared common_deps
# target.
tensorboard_ts_library(
name = "legacy_ts",
srcs = [
"colorScale.ts",
"palettes.ts",
],
deps_mgmt = "off",
runtime = "nodejs",
deps = ["//tensorflow/tensorboard/components:common_deps"],
)

View File

@ -0,0 +1,26 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
licenses(["notice"]) # Apache 2.0
# Demo page for tf-color-scale, registered under "/tf-color-scale/demo".
# bazel run //third_party/tensorflow/tensorboard/components/tf_color_scale/demo
# NOTE(review): the label above starts with //third_party/tensorflow, while the
# deps below are rooted at //tensorflow -- confirm which prefix applies here.
webfiles(
    name = "demo",
    srcs = ["index.html"],
    path = "/tf-color-scale/demo",
    deps = [
        "//tensorflow/tensorboard/components/tf_color_scale",
        "//tensorflow/tensorboard/components/tf_imports:d3",
        "@org_polymer_iron_demo_helpers",
        "@org_polymer_paper_button",
        "@org_polymer_paper_styles",
        "@org_polymer_webcomponentsjs",
    ],
)
# ":all_files" gathers whatever glob(["**"]) matches in this demo package.
filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)

View File

@ -0,0 +1,72 @@
package(default_visibility = ["//tensorflow:internal"])
load(
"//tensorflow/tensorboard:defs.bzl",
"tensorboard_ts_development_sources",
"tensorboard_ts_devserver",
"tensorboard_ts_library",
"tensorboard_webcomponent_library",
)
licenses(["notice"]) # Apache 2.0
# Web component library for tf-color-scale; its TypeScript comes from ":ts"
# via ts_lib_deps.
# TODO(dandelion): Add runner for the test code.
tensorboard_webcomponent_library(
    name = "tf_color_scale",
    srcs = ["tf-color-scale.html"],
    ts_lib_deps = [":ts"],
    deps = ["//third_party/javascript/polymer/v1/polymer:lib"],
)
# TypeScript library for the colour-scale sources, depending on the shared
# common_deps_d3v4 target.
tensorboard_ts_library(
    name = "ts",
    srcs = [
        "colorScale.ts",
        "palettes.ts",
    ],
    deps = ["//tensorflow/tensorboard/components:common_deps_d3v4"],
)
# Test sources for the colour scale, depending on ":ts" and common_deps_d3v4.
tensorboard_ts_library(
    name = "tests",
    srcs = ["colorScaleTests.ts"],
    deps = [
        ":ts",
        "//tensorflow/tensorboard/components:common_deps_d3v4",
    ],
)
# Package-wide file list: everything glob(["**"]) matches, as ":all_files".
filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)
# Demo page for the component. Unlike the package default, this target is
# publicly visible.
tensorboard_webcomponent_library(
    name = "demo",
    srcs = ["demo.html"],
    visibility = ["//visibility:public"],
    deps = [
        ":tf_color_scale",
        "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
        "//third_party/javascript/polymer/v1/paper-button:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
    ],
)
# Dev server target wired to the ":dev_sources" manifest, with ":demo" as its
# static files and "/demo_out/bundle.js" as the serving path.
tensorboard_ts_devserver(
    name = "devserver",
    manifest = ":dev_sources",
    serving_path = "/demo_out/bundle.js",
    static_files = [":demo"],
)
# Development sources derived from ":ts"; used as the manifest of ":devserver".
tensorboard_ts_development_sources(
    name = "dev_sources",
    deps = [":ts"],
)

View File

@ -0,0 +1,102 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
licenses(["notice"]) # Apache 2.0
# webfiles target for the shared dashboard pieces: every *.html file in the
# package plus the output of ":ts", registered under "/tf-dashboard-common".
webfiles(
    name = "tf_dashboard_common",
    srcs = glob(["*.html"]) + [":ts"],
    path = "/tf-dashboard-common",
    deps = [
        "//tensorflow/tensorboard/components/tf_imports:lodash",
        "//tensorflow/tensorboard/components/tf_imports:plottable",
        "//tensorflow/tensorboard/components/tf_storage",
        "//tensorflow/tensorboard/components/vz_sorting",
        "@org_polymer",
        "@org_polymer_iron_ajax",
        "@org_polymer_iron_collapse",
        "@org_polymer_iron_icons",
        "@org_polymer_paper_button",
        "@org_polymer_paper_checkbox",
        "@org_polymer_paper_dialog",
        "@org_polymer_paper_dropdown_menu",
        "@org_polymer_paper_icon_button",
        "@org_polymer_paper_input",
        "@org_polymer_paper_item",
        "@org_polymer_paper_menu",
        "@org_polymer_paper_slider",
        "@org_polymer_paper_spinner",
        "@org_polymer_paper_styles",
        "@org_polymer_paper_toggle_button",
    ],
)
# TypeScript sources of the shared dashboard code, passed to
# tensorboard_typescript_genrule with the d3 and vz_sorting typings.
tensorboard_typescript_genrule(
    name = "ts",
    srcs = [
        "categorizer.ts",
        "dashboard-behavior.ts",
        "reload-behavior.ts",
    ],
    typings = [
        "@org_definitelytyped//:d3.d.ts",
        "//tensorflow/tensorboard/components/vz_sorting:ts_typings",
    ],
)
# All files matched by glob(["**"]) in this package, published as ":all_files".
filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)
################################################################################
# MARKED FOR DELETION
# Legacy variant of the shared dashboard component: the same *.html glob, but
# combined with ":legacy_ts" and the ":legacy" / //third_party/javascript deps.
tensorboard_webcomponent_library(
    name = "legacy",
    srcs = glob(["*.html"]) + [":legacy_ts"],
    destdir = "tf-dashboard-common",
    deps = [
        "//tensorflow/tensorboard/components:tf_imports",
        "//tensorflow/tensorboard/components/tf_storage:legacy",
        "//tensorflow/tensorboard/components/vz_sorting:legacy",
        "//third_party/javascript/polymer/v1/iron-ajax:lib",
        "//third_party/javascript/polymer/v1/iron-collapse:lib",
        "//third_party/javascript/polymer/v1/iron-icons:lib",
        "//third_party/javascript/polymer/v1/paper-button:lib",
        "//third_party/javascript/polymer/v1/paper-checkbox:lib",
        "//third_party/javascript/polymer/v1/paper-dialog:lib",
        "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
        "//third_party/javascript/polymer/v1/paper-icon-button:lib",
        "//third_party/javascript/polymer/v1/paper-input:lib",
        "//third_party/javascript/polymer/v1/paper-item:lib",
        "//third_party/javascript/polymer/v1/paper-menu:lib",
        "//third_party/javascript/polymer/v1/paper-slider:lib",
        "//third_party/javascript/polymer/v1/paper-spinner:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
        "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
        "//third_party/javascript/polymer/v1/polymer:lib",
    ],
)
# Legacy TypeScript library over the same three sources as ":ts", with
# deps_mgmt switched off and the nodejs runtime selected.
tensorboard_ts_library(
    name = "legacy_ts",
    srcs = [
        "categorizer.ts",
        "dashboard-behavior.ts",
        "reload-behavior.ts",
    ],
    deps_mgmt = "off",
    runtime = "nodejs",
    deps = [
        "//tensorflow/tensorboard/components:common_deps",
        "//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
    ],
)

View File

@ -0,0 +1,31 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
licenses(["notice"]) # Apache 2.0
# Demo pages for the shared dashboard components, registered under
# "/tf-dashboard-common/demo".
# bazel run //third_party/tensorflow/tensorboard/components/tf_dashboard_common/demo
# NOTE(review): the label above starts with //third_party/tensorflow, while the
# deps below are rooted at //tensorflow -- confirm which prefix applies here.
webfiles(
    name = "demo",
    srcs = [
        "tf-categorizer-demo.html",
        "tf-collapsable-pane-demo.html",
        "tf-multi-checkbox-demo.html",
        "tf-regex-group-demo.html",
    ],
    path = "/tf-dashboard-common/demo",
    deps = [
        "//tensorflow/tensorboard/components/tf_color_scale",
        "//tensorflow/tensorboard/components/tf_dashboard_common",
        "//tensorflow/tensorboard/components/tf_imports:d3",
        "@org_polymer_iron_flex_layout",
        "@org_polymer_paper_styles",
        "@org_polymer_webcomponentsjs",
    ],
)
# Collects whatever glob(["**"]) matches in this demo package as ":all_files".
filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)

View File

@ -57,6 +57,8 @@ plugin is required to implement two functions:
</style> </style>
</template> </template>
<script> <script>
"use strict";
Polymer({ Polymer({
is: "tf-chart-scaffold", is: "tf-chart-scaffold",
properties: { properties: {

View File

@ -0,0 +1,114 @@
package(default_visibility = ["//tensorflow:internal"])
load(
"//tensorflow/tensorboard:defs.bzl",
"tensorboard_ts_development_sources",
"tensorboard_ts_devserver",
"tensorboard_ts_library",
"tensorboard_webcomponent_library",
)
licenses(["notice"]) # Apache 2.0
# Web component library for the shared dashboard pieces. The HTML files are
# listed explicitly, and the TypeScript comes from ":ts" via ts_lib_deps.
tensorboard_webcomponent_library(
    name = "tf_dashboard_common",
    srcs = [
        "dashboard-style.html",
        "run-color-style.html",
        "scrollbar-style.html",
        "tensorboard-color.html",
        "tf-categorizer.html",
        "tf-collapsable-pane.html",
        "tf-dashboard.html",
        "tf-dashboard-layout.html",
        "tf-downloader.html",
        "tf-multi-checkbox.html",
        "tf-no-data-warning.html",
        "tf-option-selector.html",
        "tf-panes-helper.html",
        "tf-regex-group.html",
        "tf-run-selector.html",
        "tf-sidebar-helper.html",
    ],
    ts_lib_deps = [":ts"],
    deps = [
        "//third_party/javascript/plottable/v3:lib",
        "//third_party/javascript/polymer/v1/iron-ajax:lib",
        "//third_party/javascript/polymer/v1/iron-collapse:lib",
        "//third_party/javascript/polymer/v1/iron-icons:lib",
        "//third_party/javascript/polymer/v1/paper-button:lib",
        "//third_party/javascript/polymer/v1/paper-checkbox:lib",
        "//third_party/javascript/polymer/v1/paper-dialog:lib",
        "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
        "//third_party/javascript/polymer/v1/paper-icon-button:lib",
        "//third_party/javascript/polymer/v1/paper-input:lib",
        "//third_party/javascript/polymer/v1/paper-item:lib",
        "//third_party/javascript/polymer/v1/paper-menu:lib",
        "//third_party/javascript/polymer/v1/paper-slider:lib",
        "//third_party/javascript/polymer/v1/paper-spinner:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
        "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
        "//third_party/javascript/polymer/v1/polymer:lib",
    ],
)
# TypeScript library behind the shared dashboard components, depending on
# common_deps_d3v4 and the ":ts" targets of tf_storage_d3v4 and vz_sorting_d3v4.
tensorboard_ts_library(
    name = "ts",
    srcs = [
        "dashboard-behavior.ts",
        "reload-behavior.ts",
        "tf-categorizer.ts",
        "tf-multi-checkbox.ts",
        "tf-regex-group.ts",
    ],
    deps = [
        "//tensorflow/tensorboard/components:common_deps_d3v4",
        "//tensorflow/tensorboard/components/tf_storage_d3v4:ts",
        "//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
    ],
)
# Categorizer test source, depending on ":ts" and common_deps_d3v4.
tensorboard_ts_library(
    name = "tests",
    srcs = ["tf-categorizer-tests.ts"],
    deps = [
        ":ts",
        "//tensorflow/tensorboard/components:common_deps_d3v4",
    ],
)
# Demo pages for the shared dashboard components, built on top of
# ":tf_dashboard_common" and the d3v4 colour-scale component.
tensorboard_webcomponent_library(
    name = "demo",
    srcs = [
        "tf-categorizer-demo.html",
        "tf-collapsable-pane-demo.html",
        "tf-multi-checkbox-demo.html",
        "tf-regex-group-demo.html",
    ],
    deps = [
        ":tf_dashboard_common",
        "//tensorflow/tensorboard/components/tf_color_scale_d3v4:tf_color_scale",
        "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
    ],
)
# Dev server target wired to the ":dev_sources" manifest, with ":demo" as its
# static files and "/demo_out/bundle.js" as the serving path.
tensorboard_ts_devserver(
    name = "devserver",
    manifest = ":dev_sources",
    serving_path = "/demo_out/bundle.js",
    static_files = [":demo"],
)
# Development sources derived from ":ts"; used as the manifest of ":devserver".
tensorboard_ts_development_sources(
    name = "dev_sources",
    deps = [":ts"],
)
# Everything glob(["**"]) matches in this package, grouped as ":all_files".
filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)

View File

@ -55,6 +55,8 @@ plugin is required to implement two functions:
</style> </style>
</template> </template>
<script> <script>
"use strict";
Polymer({ Polymer({
is: "tf-chart-scaffold", is: "tf-chart-scaffold",
properties: { properties: {

View File

@ -0,0 +1,63 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
licenses(["notice"]) # Apache 2.0
# webfiles target for the distribution dashboard: a single HTML file,
# registered under "/tf-distribution-dashboard".
webfiles(
    name = "tf_distribution_dashboard",
    srcs = ["tf-distribution-dashboard.html"],
    path = "/tf-distribution-dashboard",
    deps = [
        "//tensorflow/tensorboard/components/tf_backend",
        "//tensorflow/tensorboard/components/tf_color_scale",
        "//tensorflow/tensorboard/components/tf_dashboard_common",
        "//tensorflow/tensorboard/components/tf_imports:lodash",
        "//tensorflow/tensorboard/components/vz_distribution_chart",
        "@org_polymer",
        "@org_polymer_iron_collapse",
        "@org_polymer_paper_icon_button",
        "@org_polymer_paper_styles",
    ],
)
# Groups every file matched by glob(["**"]) in this package as ":all_files".
filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)
################################################################################
# MARKED FOR DELETION
# Legacy variant of the distribution dashboard: the same HTML file plus
# ":legacy_ts", depending on the ":legacy" targets of its sibling components.
tensorboard_webcomponent_library(
    name = "legacy",
    srcs = [
        "tf-distribution-dashboard.html",
        ":legacy_ts",
    ],
    destdir = "tf-distribution-dashboard",
    deps = [
        "//tensorflow/tensorboard/components:tf_imports",
        "//tensorflow/tensorboard/components/tf_backend:legacy",
        "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
        "//tensorflow/tensorboard/components/vz_distribution_chart:legacy",
        "//third_party/javascript/polymer/v1/iron-collapse:lib",
        "//third_party/javascript/polymer/v1/paper-icon-button:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
        "//third_party/javascript/polymer/v1/polymer:lib",
    ],
)
# Deliberately source-less library (srcs is empty).
# This is needed: components/BUILD seeks a legacy_ts rule in this package.
tensorboard_ts_library(
    name = "legacy_ts",
    srcs = [],
    deps_mgmt = "off",
    runtime = "nodejs",
    deps = ["//tensorflow/tensorboard/components:common_deps"],
)

View File

@ -0,0 +1,26 @@
package(default_visibility = ["//tensorflow:internal"])
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
licenses(["notice"]) # Apache 2.0
# Demo page for the distribution dashboard, registered under
# "/tf-distribution-dashboard/demo"; it also depends on the demo/data package.
# bazel run //third_party/tensorflow/tensorboard/components/tf_distribution_dashboard/demo
# NOTE(review): the label above starts with //third_party/tensorflow, while the
# deps below are rooted at //tensorflow -- confirm which prefix applies here.
webfiles(
    name = "demo",
    srcs = ["index.html"],
    path = "/tf-distribution-dashboard/demo",
    deps = [
        "//tensorflow/tensorboard/components/tf_distribution_dashboard",
        "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data",
        "//tensorflow/tensorboard/components/tf_imports:d3",
        "@org_polymer_iron_demo_helpers",
        "@org_polymer_paper_styles",
        "@org_polymer_webcomponentsjs",
    ],
)
# ":all_files" covers everything glob(["**"]) matches in this demo package.
filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)

Some files were not shown because too many files have changed in this diff Show More