Merge commit for internal changes
This commit is contained in:
commit
9dd8e7aec9
3
.gitignore
vendored
3
.gitignore
vendored
@ -7,11 +7,8 @@ node_modules
|
||||
/bazel_pip
|
||||
/third_party/eigen3/mkl_include
|
||||
/third_party/mkl/*
|
||||
/third_party/py/numpy/numpy_include
|
||||
/tools/python_bin_path.sh
|
||||
/tools/git/gen
|
||||
/util/python/python_include
|
||||
/util/python/python_lib
|
||||
/pip_test
|
||||
/_python_build
|
||||
*.pyc
|
||||
|
@ -263,6 +263,7 @@ filegroup(
|
||||
"//tensorflow/contrib/seq2seq:all_files",
|
||||
"//tensorflow/contrib/session_bundle:all_files",
|
||||
"//tensorflow/contrib/session_bundle/example:all_files",
|
||||
"//tensorflow/contrib/signal:all_files",
|
||||
"//tensorflow/contrib/slim:all_files",
|
||||
"//tensorflow/contrib/slim/python/slim/data:all_files",
|
||||
"//tensorflow/contrib/slim/python/slim/nets:all_files",
|
||||
@ -326,6 +327,48 @@ filegroup(
|
||||
"//tensorflow/tensorboard/backend:all_files",
|
||||
"//tensorflow/tensorboard/backend/event_processing:all_files",
|
||||
"//tensorflow/tensorboard/components:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_audio_dashboard:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_backend:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_backend_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_color_scale:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_color_scale/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_distribution_dashboard/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_globals:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_globals_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_graph_common:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_image_dashboard/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_imports:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_imports_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_storage:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_storage_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
|
||||
"//tensorflow/tensorboard/components/tf_text_dashboard/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_data_summary:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_distribution_chart/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_histogram_timeseries/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_projector:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_sorting:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_sorting/test:all_files",
|
||||
"//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files",
|
||||
"//tensorflow/tensorboard/lib:all_files",
|
||||
"//tensorflow/tensorboard/plugins:all_files",
|
||||
"//tensorflow/tensorboard/plugins/projector:all_files",
|
||||
|
@ -28,11 +28,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
using ::testing::_;
|
||||
|
||||
class HloRematerializationTest : public HloTestBase {
|
||||
protected:
|
||||
// Creates and returns a computation which can benefit from
|
||||
@ -145,11 +147,9 @@ TEST_F(HloRematerializationTest, SingleComputation) {
|
||||
// Find and save the original broadcast instruction which should be
|
||||
// rematerialized.
|
||||
const HloInstruction* slice = computation->root_instruction();
|
||||
ASSERT_EQ(HloOpcode::kSlice, slice->opcode());
|
||||
ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
|
||||
const HloInstruction* concat = slice->operand(0);
|
||||
ASSERT_EQ(HloOpcode::kConcatenate, concat->opcode());
|
||||
const HloInstruction* bcast = concat->operand(0);
|
||||
ASSERT_EQ(HloOpcode::kBroadcast, bcast->opcode());
|
||||
|
||||
SequentialHloOrdering::HloModuleSequence sequence;
|
||||
// Computation requires 16KB without rematerialization, but uses only 12KB
|
||||
@ -165,8 +165,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
|
||||
|
||||
// The broadcast should have been rematerialized.
|
||||
const HloInstruction* remat_bcast = concat->operand(0);
|
||||
EXPECT_EQ(HloOpcode::kBroadcast, remat_bcast->opcode());
|
||||
EXPECT_NE(bcast, remat_bcast);
|
||||
EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));
|
||||
|
||||
// The rematerialized broadcast should be immediate before the concat in the
|
||||
// sequence.
|
||||
|
@ -68,9 +68,8 @@ void CleanNodeName(string* name) {
|
||||
}
|
||||
|
||||
Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
|
||||
LOG(INFO) << "Adding computation " << computation.name();
|
||||
VLOG(2) << "Adding computation " << computation.name();
|
||||
for (auto embedded : computation.MakeEmbeddedComputationsList()) {
|
||||
LOG(INFO) << "Adding embedded computation " << embedded->name();
|
||||
for (auto& instruction : embedded->instructions()) {
|
||||
TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ def _init_clusters_random(data, num_clusters, random_seed):
|
||||
maxval=math_ops.cast(num_data, dtypes.int64),
|
||||
seed=random_seed,
|
||||
dtype=dtypes.int64)
|
||||
indices = indices % math_ops.cast(num_data, dtypes.int64)
|
||||
indices %= math_ops.cast(num_data, dtypes.int64)
|
||||
clusters_init = embedding_lookup(data, indices, partition_strategy='div')
|
||||
return clusters_init
|
||||
|
||||
|
@ -35,8 +35,8 @@ class GridRNNCellTest(test.TestCase):
|
||||
|
||||
def testGrid2BasicLSTMCell(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.2)) as root_scope:
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.2)) as root_scope:
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
|
||||
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
|
||||
@ -51,21 +51,22 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[1].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x:
|
||||
np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertEqual(res_s[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[0].h.shape, (1, 2))
|
||||
self.assertEqual(res_s[1].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[1].h.shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g, ([[0.36617181, 0.36617181]], ))
|
||||
self.assertAllClose(res_s, (([[0.71053141, 0.71053141]],
|
||||
[[0.36617181, 0.36617181]]),
|
||||
([[0.72320831, 0.80555487]],
|
||||
[[0.39102408, 0.42150158]])))
|
||||
self.assertAllClose(res_g, ([[0.36617181, 0.36617181]],))
|
||||
self.assertAllClose(
|
||||
res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
|
||||
([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
|
||||
|
||||
# emulate a loop through the input sequence,
|
||||
# where we call cell() multiple times
|
||||
@ -78,17 +79,17 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s2[1].h.get_shape(), (1, 2))
|
||||
|
||||
res_g2, res_s2 = sess.run([g2, s2],
|
||||
{x: np.array([[2., 2., 2.]]), m: res_s})
|
||||
{x: np.array([[2., 2., 2.]]),
|
||||
m: res_s})
|
||||
self.assertEqual(res_g2[0].shape, (1, 2))
|
||||
self.assertEqual(res_s2[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s2[0].h.shape, (1, 2))
|
||||
self.assertEqual(res_s2[1].c.shape, (1, 2))
|
||||
self.assertEqual(res_s2[1].h.shape, (1, 2))
|
||||
self.assertAllClose(res_g2[0], [[0.58847463, 0.58847463]])
|
||||
self.assertAllClose(res_s2, (([[1.40469193, 1.40469193]],
|
||||
[[0.58847463, 0.58847463]]),
|
||||
([[0.97726452, 1.04626071]],
|
||||
[[0.4927212, 0.51137757]])))
|
||||
self.assertAllClose(
|
||||
res_s2, (([[1.40469193, 1.40469193]], [[0.58847463, 0.58847463]]),
|
||||
([[0.97726452, 1.04626071]], [[0.4927212, 0.51137757]])))
|
||||
|
||||
def testGrid2BasicLSTMCellTied(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
@ -108,10 +109,12 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[1].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x:
|
||||
np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertEqual(res_s[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[0].h.shape, (1, 2))
|
||||
@ -119,29 +122,27 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(res_s[1].h.shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]])
|
||||
self.assertAllClose(res_s, (([[0.71053141, 0.71053141]],
|
||||
[[0.36617181, 0.36617181]]),
|
||||
([[0.72320831, 0.80555487]],
|
||||
[[0.39102408, 0.42150158]])))
|
||||
self.assertAllClose(
|
||||
res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
|
||||
([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
|
||||
|
||||
res_g, res_s = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res_s})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]])
|
||||
self.assertAllClose(res_s, (([[0.71200621, 0.71200621]],
|
||||
[[0.36703536, 0.36703536]]),
|
||||
([[0.80941606, 0.87550586]],
|
||||
[[0.40108523, 0.42199609]])))
|
||||
self.assertAllClose(
|
||||
res_s, (([[0.71200621, 0.71200621]], [[0.36703536, 0.36703536]]),
|
||||
([[0.80941606, 0.87550586]], [[0.40108523, 0.42199609]])))
|
||||
|
||||
def testGrid2BasicLSTMCellWithRelu(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.2)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.2)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
|
||||
cell = grid_rnn_cell.Grid2BasicLSTMCell(
|
||||
2, tied=False, non_recurrent_fn=nn_ops.relu)
|
||||
self.assertEqual(cell.state_size, ((2, 2), ))
|
||||
self.assertEqual(cell.state_size, ((2, 2),))
|
||||
|
||||
g, s = cell(x, m)
|
||||
self.assertEqual(g[0].get_shape(), (1, 2))
|
||||
@ -149,21 +150,22 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[0].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x: np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertAllClose(res_g[0], [[0.31667367, 0.31667367]])
|
||||
self.assertAllClose(res_s, (([[0.29530135, 0.37520045]],
|
||||
[[0.17044567, 0.21292259]]), ))
|
||||
[[0.17044567, 0.21292259]]),))
|
||||
|
||||
"""LSTMCell
|
||||
"""
|
||||
|
||||
def testGrid2LSTMCell(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
|
||||
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
|
||||
@ -178,10 +180,12 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[1].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x:
|
||||
np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertEqual(res_s[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[0].h.shape, (1, 2))
|
||||
@ -189,15 +193,14 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(res_s[1].h.shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
|
||||
self.assertAllClose(res_s, (([[2.41515064, 2.41515064]],
|
||||
[[0.95686918, 0.95686918]]),
|
||||
([[1.38917875, 1.49043763]],
|
||||
[[0.83884692, 0.86036491]])))
|
||||
self.assertAllClose(
|
||||
res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
|
||||
([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
|
||||
|
||||
def testGrid2LSTMCellTied(self):
|
||||
with self.test_session(use_gpu=False) as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
|
||||
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
|
||||
@ -212,10 +215,12 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[1].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x:
|
||||
np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertEqual(res_s[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[0].h.shape, (1, 2))
|
||||
@ -223,15 +228,14 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(res_s[1].h.shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
|
||||
self.assertAllClose(res_s, (([[2.41515064, 2.41515064]],
|
||||
[[0.95686918, 0.95686918]]),
|
||||
([[1.38917875, 1.49043763]],
|
||||
[[0.83884692, 0.86036491]])))
|
||||
self.assertAllClose(
|
||||
res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
|
||||
([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
|
||||
|
||||
def testGrid2LSTMCellWithRelu(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
|
||||
cell = grid_rnn_cell.Grid2LSTMCell(
|
||||
@ -244,21 +248,22 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[0].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x: np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertAllClose(res_g[0], [[2.1831727, 2.1831727]])
|
||||
self.assertAllClose(res_s, (([[0.92270052, 1.02325559]],
|
||||
[[0.66159075, 0.70475441]]), ))
|
||||
[[0.66159075, 0.70475441]]),))
|
||||
|
||||
"""RNNCell
|
||||
"""
|
||||
|
||||
def testGrid2BasicRNNCell(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([2, 2])
|
||||
m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
|
||||
cell = grid_rnn_cell.Grid2BasicRNNCell(2)
|
||||
@ -270,26 +275,26 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[1].get_shape(), (2, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1.], [2., 2.]]),
|
||||
m: (np.array([[0.1, 0.1], [0.2, 0.2]]),
|
||||
np.array([[0.1, 0.1], [0.2, 0.2]]))})
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x:
|
||||
np.array([[1., 1.], [2., 2.]]),
|
||||
m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
|
||||
[0.2, 0.2]]))
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (2, 2))
|
||||
self.assertEqual(res_s[0].shape, (2, 2))
|
||||
self.assertEqual(res_s[1].shape, (2, 2))
|
||||
|
||||
self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
|
||||
[0.99480951, 0.99480951]], ))
|
||||
self.assertAllClose(res_s,
|
||||
([[0.94685763, 0.94685763],
|
||||
[0.99480951, 0.99480951]],
|
||||
[[0.80049908, 0.80049908],
|
||||
[0.97574311, 0.97574311]]))
|
||||
[0.99480951, 0.99480951]],))
|
||||
self.assertAllClose(
|
||||
res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
|
||||
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
|
||||
|
||||
def testGrid2BasicRNNCellTied(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([2, 2])
|
||||
m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
|
||||
cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
|
||||
@ -301,55 +306,55 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[1].get_shape(), (2, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1.], [2., 2.]]),
|
||||
m: (np.array([[0.1, 0.1], [0.2, 0.2]]),
|
||||
np.array([[0.1, 0.1], [0.2, 0.2]]))})
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x:
|
||||
np.array([[1., 1.], [2., 2.]]),
|
||||
m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
|
||||
[0.2, 0.2]]))
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (2, 2))
|
||||
self.assertEqual(res_s[0].shape, (2, 2))
|
||||
self.assertEqual(res_s[1].shape, (2, 2))
|
||||
|
||||
self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
|
||||
[0.99480951, 0.99480951]], ))
|
||||
self.assertAllClose(res_s,
|
||||
([[0.94685763, 0.94685763],
|
||||
[0.99480951, 0.99480951]],
|
||||
[[0.80049908, 0.80049908],
|
||||
[0.97574311, 0.97574311]]))
|
||||
[0.99480951, 0.99480951]],))
|
||||
self.assertAllClose(
|
||||
res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
|
||||
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
|
||||
|
||||
def testGrid2BasicRNNCellWithRelu(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = (array_ops.zeros([1, 2]), )
|
||||
cell = grid_rnn_cell.Grid2BasicRNNCell(
|
||||
2, non_recurrent_fn=nn_ops.relu)
|
||||
self.assertEqual(cell.state_size, (2, ))
|
||||
m = (array_ops.zeros([1, 2]),)
|
||||
cell = grid_rnn_cell.Grid2BasicRNNCell(2, non_recurrent_fn=nn_ops.relu)
|
||||
self.assertEqual(cell.state_size, (2,))
|
||||
|
||||
g, s = cell(x, m)
|
||||
self.assertEqual(g[0].get_shape(), (1, 2))
|
||||
self.assertEqual(s[0].get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run([g, s], {x: np.array([[1., 1.]]),
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1.]]),
|
||||
m: np.array([[0.1, 0.1]])})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertEqual(res_s[0].shape, (1, 2))
|
||||
self.assertAllClose(res_g, ([[1.80049896, 1.80049896]], ))
|
||||
self.assertAllClose(res_s, ([[0.80049896, 0.80049896]], ))
|
||||
self.assertAllClose(res_g, ([[1.80049896, 1.80049896]],))
|
||||
self.assertAllClose(res_s, ([[0.80049896, 0.80049896]],))
|
||||
|
||||
"""1-LSTM
|
||||
"""
|
||||
|
||||
def testGrid1LSTMCell(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)) as root_scope:
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), )
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
|
||||
cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True)
|
||||
self.assertEqual(cell.state_size, ((2, 2), ))
|
||||
self.assertEqual(cell.state_size, ((2, 2),))
|
||||
|
||||
g, s = cell(x, m)
|
||||
self.assertEqual(g[0].get_shape(), (1, 2))
|
||||
@ -357,17 +362,17 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[0].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x: np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertEqual(res_s[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[0].h.shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g, ([[0.91287315, 0.91287315]], ))
|
||||
self.assertAllClose(res_s,
|
||||
(([[2.26285243, 2.26285243]],
|
||||
[[0.91287315, 0.91287315]]), ))
|
||||
self.assertAllClose(res_g, ([[0.91287315, 0.91287315]],))
|
||||
self.assertAllClose(res_s, (([[2.26285243, 2.26285243]],
|
||||
[[0.91287315, 0.91287315]]),))
|
||||
|
||||
root_scope.reuse_variables()
|
||||
|
||||
@ -383,10 +388,9 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(res_s2[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s2[0].h.shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]], ))
|
||||
self.assertAllClose(res_s2,
|
||||
(([[2.79966092, 2.79966092]],
|
||||
[[0.9032144, 0.9032144]]), ))
|
||||
self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]],))
|
||||
self.assertAllClose(res_s2, (([[2.79966092, 2.79966092]],
|
||||
[[0.9032144, 0.9032144]]),))
|
||||
|
||||
g3, s3 = cell(x2, m)
|
||||
self.assertEqual(g3[0].get_shape(), (1, 2))
|
||||
@ -398,18 +402,17 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(res_g3[0].shape, (1, 2))
|
||||
self.assertEqual(res_s3[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s3[0].h.shape, (1, 2))
|
||||
self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]], ))
|
||||
self.assertAllClose(res_s3,
|
||||
(([[3.3529923, 3.3529923]],
|
||||
[[0.92727238, 0.92727238]]), ))
|
||||
self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]],))
|
||||
self.assertAllClose(res_s3, (([[3.3529923, 3.3529923]],
|
||||
[[0.92727238, 0.92727238]]),))
|
||||
|
||||
"""3-LSTM
|
||||
"""
|
||||
|
||||
def testGrid3LSTMCell(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
|
||||
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
|
||||
@ -427,11 +430,13 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[2].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x:
|
||||
np.array([[1., 1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])),
|
||||
(np.array([[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))})
|
||||
(np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])), (np.array(
|
||||
[[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))
|
||||
})
|
||||
self.assertEqual(res_g[0].shape, (1, 2))
|
||||
self.assertEqual(res_s[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[0].h.shape, (1, 2))
|
||||
@ -440,21 +445,19 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(res_s[2].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[2].h.shape, (1, 2))
|
||||
|
||||
self.assertAllClose(res_g, ([[0.96892911, 0.96892911]], ))
|
||||
self.assertAllClose(res_s, (([[2.45227885, 2.45227885]],
|
||||
[[0.96892911, 0.96892911]]),
|
||||
([[1.33592629, 1.4373529]],
|
||||
[[0.80867189, 0.83247656]]),
|
||||
([[0.7317788, 0.63205892]],
|
||||
[[0.56548983, 0.50446129]])))
|
||||
self.assertAllClose(res_g, ([[0.96892911, 0.96892911]],))
|
||||
self.assertAllClose(
|
||||
res_s, (([[2.45227885, 2.45227885]], [[0.96892911, 0.96892911]]),
|
||||
([[1.33592629, 1.4373529]], [[0.80867189, 0.83247656]]),
|
||||
([[0.7317788, 0.63205892]], [[0.56548983, 0.50446129]])))
|
||||
|
||||
"""Edge cases
|
||||
"""
|
||||
|
||||
def testGridRNNEdgeCasesLikeRelu(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([3, 2])
|
||||
m = ()
|
||||
|
||||
@ -471,18 +474,18 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s, ())
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
|
||||
res_g, res_s = sess.run([g, s],
|
||||
{x: np.array([[1., -1.], [-2, 1], [2, -1]])})
|
||||
self.assertEqual(res_g[0].shape, (3, 2))
|
||||
self.assertEqual(res_s, ())
|
||||
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]], ))
|
||||
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
|
||||
|
||||
def testGridRNNEdgeCasesNoOutput(self):
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 2])
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), )
|
||||
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
|
||||
|
||||
# This cell produces no output
|
||||
cell = grid_rnn_cell.GridRNNCell(
|
||||
@ -498,9 +501,10 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s[0].h.get_shape(), (1, 2))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res_g, res_s = sess.run(
|
||||
[g, s], {x: np.array([[1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])), )})
|
||||
res_g, res_s = sess.run([g, s], {
|
||||
x: np.array([[1., 1.]]),
|
||||
m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])),)
|
||||
})
|
||||
self.assertEqual(res_g, ())
|
||||
self.assertEqual(res_s[0].c.shape, (1, 2))
|
||||
self.assertEqual(res_s[0].h.shape, (1, 2))
|
||||
@ -561,8 +565,9 @@ class GridRNNCellTest(test.TestCase):
|
||||
cell = grid_rnn_cell.Grid2LSTMCell(
|
||||
num_units=num_units, non_recurrent_fn=nn_ops.relu)
|
||||
|
||||
inputs = max_length * [array_ops.placeholder(
|
||||
dtypes.float32, shape=(batch_size, input_size))]
|
||||
inputs = max_length * [
|
||||
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
@ -600,8 +605,9 @@ class GridRNNCellTest(test.TestCase):
|
||||
cell = grid_rnn_cell.Grid3LSTMCell(
|
||||
num_units=num_units, non_recurrent_fn=nn_ops.relu)
|
||||
|
||||
inputs = max_length * [array_ops.placeholder(
|
||||
dtypes.float32, shape=(batch_size, input_size))]
|
||||
inputs = max_length * [
|
||||
array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
|
||||
]
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
|
||||
@ -671,19 +677,17 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertTrue(np.all(np.isfinite(v)))
|
||||
|
||||
def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self):
|
||||
"""Test for #4296
|
||||
"""
|
||||
"""Test for #4296."""
|
||||
input_size = 5
|
||||
max_length = 6 # unrolled up to this length
|
||||
num_units = 2
|
||||
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)
|
||||
|
||||
inputs = max_length * [
|
||||
array_ops.placeholder(
|
||||
dtypes.float32, shape=(None, input_size))
|
||||
array_ops.placeholder(dtypes.float32, shape=(None, input_size))
|
||||
]
|
||||
|
||||
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
|
||||
@ -700,8 +704,7 @@ class GridRNNCellTest(test.TestCase):
|
||||
sess.run(variables.global_variables_initializer())
|
||||
|
||||
input_value = np.ones((3, input_size))
|
||||
values = sess.run(outputs + [state],
|
||||
feed_dict={inputs[0]: input_value})
|
||||
values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
|
||||
for tp in values[:-1]:
|
||||
for v in tp:
|
||||
self.assertTrue(np.all(np.isfinite(v)))
|
||||
@ -710,18 +713,15 @@ class GridRNNCellTest(test.TestCase):
|
||||
for v in st:
|
||||
self.assertTrue(np.all(np.isfinite(v)))
|
||||
|
||||
|
||||
def testGrid2LSTMCellLegacy(self):
|
||||
"""Test for legacy case (when state_is_tuple=False)
|
||||
"""
|
||||
"""Test for legacy case (when state_is_tuple=False)."""
|
||||
with self.test_session() as sess:
|
||||
with variable_scope.variable_scope('root',
|
||||
initializer=init_ops.constant_initializer(0.5)):
|
||||
with variable_scope.variable_scope(
|
||||
'root', initializer=init_ops.constant_initializer(0.5)):
|
||||
x = array_ops.zeros([1, 3])
|
||||
m = array_ops.zeros([1, 8])
|
||||
cell = grid_rnn_cell.Grid2LSTMCell(2, use_peepholes=True,
|
||||
state_is_tuple=False,
|
||||
output_is_tuple=False)
|
||||
cell = grid_rnn_cell.Grid2LSTMCell(
|
||||
2, use_peepholes=True, state_is_tuple=False, output_is_tuple=False)
|
||||
self.assertEqual(cell.state_size, 8)
|
||||
|
||||
g, s = cell(x, m)
|
||||
@ -729,15 +729,17 @@ class GridRNNCellTest(test.TestCase):
|
||||
self.assertEqual(s.get_shape(), (1, 8))
|
||||
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run(
|
||||
[g, s], {x: np.array([[1., 1., 1.]]),
|
||||
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
|
||||
res = sess.run([g, s], {
|
||||
x: np.array([[1., 1., 1.]]),
|
||||
m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
|
||||
})
|
||||
self.assertEqual(res[0].shape, (1, 2))
|
||||
self.assertEqual(res[1].shape, (1, 8))
|
||||
self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
|
||||
self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
|
||||
0.95686918, 1.38917875, 1.49043763,
|
||||
0.83884692, 0.86036491]])
|
||||
self.assertAllClose(res[1], [[
|
||||
2.41515064, 2.41515064, 0.95686918, 0.95686918, 1.38917875,
|
||||
1.49043763, 0.83884692, 0.86036491
|
||||
]])
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -107,11 +107,11 @@ class GridRNNCell(rnn.RNNCell):
|
||||
TypeError: if cell_fn does not return an RNNCell instance.
|
||||
"""
|
||||
if not state_is_tuple:
|
||||
logging.warning("%s: Using a concatenated state is slower and will "
|
||||
"soon be deprecated. Use state_is_tuple=True.", self)
|
||||
logging.warning('%s: Using a concatenated state is slower and will '
|
||||
'soon be deprecated. Use state_is_tuple=True.', self)
|
||||
if not output_is_tuple:
|
||||
logging.warning("%s: Using a concatenated output is slower and will"
|
||||
"soon be deprecated. Use output_is_tuple=True.", self)
|
||||
logging.warning('%s: Using a concatenated output is slower and will'
|
||||
'soon be deprecated. Use output_is_tuple=True.', self)
|
||||
|
||||
if num_dims < 1:
|
||||
raise ValueError('dims must be >= 1: {}'.format(num_dims))
|
||||
@ -126,9 +126,7 @@ class GridRNNCell(rnn.RNNCell):
|
||||
|
||||
if cell_fn is None:
|
||||
my_cell_fn = functools.partial(
|
||||
rnn.LSTMCell,
|
||||
num_units=num_units,
|
||||
state_is_tuple=state_is_tuple)
|
||||
rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
|
||||
else:
|
||||
my_cell_fn = lambda: cell_fn(num_units)
|
||||
if tied:
|
||||
@ -136,9 +134,8 @@ class GridRNNCell(rnn.RNNCell):
|
||||
else:
|
||||
self._cells = [my_cell_fn() for _ in range(num_dims)]
|
||||
if not isinstance(self._cells[0], rnn.RNNCell):
|
||||
raise TypeError(
|
||||
'cell_fn must return an RNNCell instance, saw: %s'
|
||||
% type(self._cells[0]))
|
||||
raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
|
||||
type(self._cells[0]))
|
||||
|
||||
if self._output_is_tuple:
|
||||
self._output_size = tuple(self._cells[0].output_size
|
||||
@ -201,26 +198,36 @@ class GridRNNCell(rnn.RNNCell):
|
||||
if self._output_is_tuple:
|
||||
output = tuple(output_tensors)
|
||||
else:
|
||||
if len(output_tensors) == 0:
|
||||
output = array_ops.zeros([0, 0], dtype)
|
||||
else:
|
||||
if output_tensors:
|
||||
output = array_ops.concat(output_tensors, 1)
|
||||
else:
|
||||
output = array_ops.zeros([0, 0], dtype)
|
||||
|
||||
if self._state_is_tuple:
|
||||
states = tuple(new_state[i] for i in self._config.recurrents)
|
||||
else:
|
||||
# concat each state first, then flatten the whole thing
|
||||
state_tensors = [x for i in self._config.recurrents
|
||||
for x in new_state[i]]
|
||||
if len(state_tensors) == 0:
|
||||
states = array_ops.zeros([0, 0], dtype)
|
||||
else:
|
||||
state_tensors = [
|
||||
x for i in self._config.recurrents for x in new_state[i]
|
||||
]
|
||||
if state_tensors:
|
||||
states = array_ops.concat(state_tensors, 1)
|
||||
else:
|
||||
states = array_ops.zeros([0, 0], dtype)
|
||||
|
||||
return output, states
|
||||
|
||||
def _extract_states(self, state):
|
||||
"""Extract the cell and previous output tensors from the given state
|
||||
"""Extract the cell and previous output tensors from the given state.
|
||||
|
||||
Args:
|
||||
state: The RNN state.
|
||||
|
||||
Returns:
|
||||
Tuple of the cell value, previous output, and cell_output_size.
|
||||
|
||||
Raises:
|
||||
ValueError: If len(self._config.recurrents) != len(state).
|
||||
"""
|
||||
conf = self._config
|
||||
|
||||
@ -238,8 +245,8 @@ class GridRNNCell(rnn.RNNCell):
|
||||
|
||||
if self._state_is_tuple:
|
||||
if len(conf.recurrents) != len(state):
|
||||
raise ValueError("Expected state as a tuple of {} "
|
||||
"element".format(len(conf.recurrents)))
|
||||
raise ValueError('Expected state as a tuple of {} '
|
||||
'element'.format(len(conf.recurrents)))
|
||||
|
||||
for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
|
||||
if cell_output_size > 0:
|
||||
@ -247,8 +254,9 @@ class GridRNNCell(rnn.RNNCell):
|
||||
else:
|
||||
m_prev[recurrent_dim] = recurrent_state
|
||||
else:
|
||||
for recurrent_dim, start_idx in zip(conf.recurrents, range(
|
||||
0, self.state_size, total_cell_state_size)):
|
||||
for recurrent_dim, start_idx in zip(conf.recurrents,
|
||||
range(0, self.state_size,
|
||||
total_cell_state_size)):
|
||||
if cell_output_size > 0:
|
||||
c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
|
||||
[-1, conf.num_units])
|
||||
@ -260,16 +268,25 @@ class GridRNNCell(rnn.RNNCell):
|
||||
return c_prev, m_prev, cell_output_size
|
||||
|
||||
def _project_input(self, inputs, c_prev, m_prev, with_c):
|
||||
"""Fills in c_prev and m_prev with projected input, for input dimensions
|
||||
"""Fills in c_prev and m_prev with projected input, for input dimensions.
|
||||
|
||||
Args:
|
||||
inputs: inputs tensor
|
||||
c_prev: cell value
|
||||
m_prev: previous output
|
||||
with_c: boolean; whether to include project_c.
|
||||
|
||||
Raises:
|
||||
ValueError: if len(self._config.input) != len(inputs)
|
||||
"""
|
||||
conf = self._config
|
||||
|
||||
if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0
|
||||
and len(conf.inputs) > 0):
|
||||
if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and
|
||||
conf.inputs):
|
||||
if isinstance(inputs, tuple):
|
||||
if len(conf.inputs) != len(inputs):
|
||||
raise ValueError("Expect inputs as a tuple of {} "
|
||||
"tensors".format(len(conf.inputs)))
|
||||
raise ValueError('Expect inputs as a tuple of {} '
|
||||
'tensors'.format(len(conf.inputs)))
|
||||
input_splits = inputs
|
||||
else:
|
||||
input_splits = array_ops.split(
|
||||
@ -289,7 +306,10 @@ class GridRNNCell(rnn.RNNCell):
|
||||
c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
|
||||
|
||||
def _cell_state_size(self):
|
||||
"""Total size of the state of the inner cell used in this grid
|
||||
"""Total size of the state of the inner cell used in this grid.
|
||||
|
||||
Returns:
|
||||
Total size of the state of the inner cell.
|
||||
"""
|
||||
state_sizes = self._cells[0].state_size
|
||||
if isinstance(state_sizes, tuple):
|
||||
@ -306,10 +326,15 @@ class Grid1BasicRNNCell(GridRNNCell):
|
||||
|
||||
def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
|
||||
super(Grid1BasicRNNCell, self).__init__(
|
||||
num_units=num_units, num_dims=1,
|
||||
input_dims=0, output_dims=0, priority_dims=0, tied=False,
|
||||
num_units=num_units,
|
||||
num_dims=1,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
tied=False,
|
||||
cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
class Grid2BasicRNNCell(GridRNNCell):
|
||||
@ -322,32 +347,50 @@ class Grid2BasicRNNCell(GridRNNCell):
|
||||
specified.
|
||||
"""
|
||||
|
||||
def __init__(self, num_units, tied=False, non_recurrent_fn=None,
|
||||
state_is_tuple=True, output_is_tuple=True):
|
||||
def __init__(self,
|
||||
num_units,
|
||||
tied=False,
|
||||
non_recurrent_fn=None,
|
||||
state_is_tuple=True,
|
||||
output_is_tuple=True):
|
||||
super(Grid2BasicRNNCell, self).__init__(
|
||||
num_units=num_units, num_dims=2,
|
||||
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
|
||||
num_units=num_units,
|
||||
num_dims=2,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
tied=tied,
|
||||
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
||||
cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
|
||||
non_recurrent_fn=non_recurrent_fn,
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
class Grid1BasicLSTMCell(GridRNNCell):
|
||||
"""1D BasicLSTM cell"""
|
||||
"""1D BasicLSTM cell."""
|
||||
|
||||
def __init__(self, num_units, forget_bias=1,
|
||||
state_is_tuple=True, output_is_tuple=True):
|
||||
def __init__(self,
|
||||
num_units,
|
||||
forget_bias=1,
|
||||
state_is_tuple=True,
|
||||
output_is_tuple=True):
|
||||
def cell_fn(n):
|
||||
return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
|
||||
super(Grid1BasicLSTMCell, self).__init__(
|
||||
num_units=num_units, num_dims=1,
|
||||
input_dims=0, output_dims=0, priority_dims=0, tied=False,
|
||||
cell_fn=lambda n: rnn.BasicLSTMCell(
|
||||
num_units=n, forget_bias=forget_bias),
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
num_units=num_units,
|
||||
num_dims=1,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
tied=False,
|
||||
cell_fn=cell_fn,
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
class Grid2BasicLSTMCell(GridRNNCell):
|
||||
"""2D BasicLSTM cell
|
||||
"""2D BasicLSTM cell.
|
||||
|
||||
This creates a 2D cell which receives input and gives output in the first
|
||||
dimension.
|
||||
@ -363,36 +406,53 @@ class Grid2BasicLSTMCell(GridRNNCell):
|
||||
forget_bias=1,
|
||||
state_is_tuple=True,
|
||||
output_is_tuple=True):
|
||||
def cell_fn(n):
|
||||
return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
|
||||
super(Grid2BasicLSTMCell, self).__init__(
|
||||
num_units=num_units, num_dims=2,
|
||||
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
|
||||
num_units=num_units,
|
||||
num_dims=2,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
tied=tied,
|
||||
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
||||
cell_fn=lambda n: rnn.BasicLSTMCell(
|
||||
num_units=n, forget_bias=forget_bias),
|
||||
cell_fn=cell_fn,
|
||||
non_recurrent_fn=non_recurrent_fn,
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
class Grid1LSTMCell(GridRNNCell):
|
||||
"""1D LSTM cell
|
||||
"""1D LSTM cell.
|
||||
|
||||
This is different from Grid1BasicLSTMCell because it gives options to
|
||||
specify the forget bias and enabling peepholes
|
||||
specify the forget bias and enabling peepholes.
|
||||
"""
|
||||
|
||||
def __init__(self, num_units, use_peepholes=False, forget_bias=1.0,
|
||||
state_is_tuple=True, output_is_tuple=True):
|
||||
def __init__(self,
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
forget_bias=1.0,
|
||||
state_is_tuple=True,
|
||||
output_is_tuple=True):
|
||||
|
||||
def cell_fn(n):
|
||||
return rnn.LSTMCell(
|
||||
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
|
||||
|
||||
super(Grid1LSTMCell, self).__init__(
|
||||
num_units=num_units, num_dims=1,
|
||||
input_dims=0, output_dims=0, priority_dims=0,
|
||||
cell_fn=lambda n: rnn.LSTMCell(
|
||||
num_units=n, use_peepholes=use_peepholes,
|
||||
forget_bias=forget_bias),
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
num_units=num_units,
|
||||
num_dims=1,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
cell_fn=cell_fn,
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
class Grid2LSTMCell(GridRNNCell):
|
||||
"""2D LSTM cell
|
||||
"""2D LSTM cell.
|
||||
|
||||
This creates a 2D cell which receives input and gives output in the first
|
||||
dimension.
|
||||
@ -408,19 +468,27 @@ class Grid2LSTMCell(GridRNNCell):
|
||||
forget_bias=1.0,
|
||||
state_is_tuple=True,
|
||||
output_is_tuple=True):
|
||||
|
||||
def cell_fn(n):
|
||||
return rnn.LSTMCell(
|
||||
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
|
||||
|
||||
super(Grid2LSTMCell, self).__init__(
|
||||
num_units=num_units, num_dims=2,
|
||||
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
|
||||
num_units=num_units,
|
||||
num_dims=2,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
tied=tied,
|
||||
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
||||
cell_fn=lambda n: rnn.LSTMCell(
|
||||
num_units=n, forget_bias=forget_bias,
|
||||
use_peepholes=use_peepholes),
|
||||
cell_fn=cell_fn,
|
||||
non_recurrent_fn=non_recurrent_fn,
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
class Grid3LSTMCell(GridRNNCell):
|
||||
"""3D BasicLSTM cell
|
||||
"""3D BasicLSTM cell.
|
||||
|
||||
This creates a 2D cell which receives input and gives output in the first
|
||||
dimension.
|
||||
@ -437,19 +505,27 @@ class Grid3LSTMCell(GridRNNCell):
|
||||
forget_bias=1.0,
|
||||
state_is_tuple=True,
|
||||
output_is_tuple=True):
|
||||
|
||||
def cell_fn(n):
|
||||
return rnn.LSTMCell(
|
||||
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
|
||||
|
||||
super(Grid3LSTMCell, self).__init__(
|
||||
num_units=num_units, num_dims=3,
|
||||
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
|
||||
num_units=num_units,
|
||||
num_dims=3,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
tied=tied,
|
||||
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
||||
cell_fn=lambda n: rnn.LSTMCell(
|
||||
num_units=n, forget_bias=forget_bias,
|
||||
use_peepholes=use_peepholes),
|
||||
cell_fn=cell_fn,
|
||||
non_recurrent_fn=non_recurrent_fn,
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
class Grid2GRUCell(GridRNNCell):
|
||||
"""2D LSTM cell
|
||||
"""2D LSTM cell.
|
||||
|
||||
This creates a 2D cell which receives input and gives output in the first
|
||||
dimension.
|
||||
@ -457,23 +533,31 @@ class Grid2GRUCell(GridRNNCell):
|
||||
specified.
|
||||
"""
|
||||
|
||||
def __init__(self, num_units, tied=False, non_recurrent_fn=None,
|
||||
state_is_tuple=True, output_is_tuple=True):
|
||||
def __init__(self,
|
||||
num_units,
|
||||
tied=False,
|
||||
non_recurrent_fn=None,
|
||||
state_is_tuple=True,
|
||||
output_is_tuple=True):
|
||||
super(Grid2GRUCell, self).__init__(
|
||||
num_units=num_units, num_dims=2,
|
||||
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
|
||||
num_units=num_units,
|
||||
num_dims=2,
|
||||
input_dims=0,
|
||||
output_dims=0,
|
||||
priority_dims=0,
|
||||
tied=tied,
|
||||
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
||||
cell_fn=lambda n: rnn.GRUCell(num_units=n),
|
||||
non_recurrent_fn=non_recurrent_fn,
|
||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
||||
state_is_tuple=state_is_tuple,
|
||||
output_is_tuple=output_is_tuple)
|
||||
|
||||
|
||||
"""Helpers
|
||||
"""
|
||||
# Helpers
|
||||
|
||||
_GridRNNDimension = namedtuple(
|
||||
'_GridRNNDimension',
|
||||
['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
|
||||
_GridRNNDimension = namedtuple('_GridRNNDimension', [
|
||||
'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
|
||||
])
|
||||
|
||||
_GridRNNConfig = namedtuple('_GridRNNConfig',
|
||||
['num_dims', 'dims', 'inputs', 'outputs',
|
||||
@ -507,8 +591,8 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
|
||||
is_input=(i in input_dims),
|
||||
is_output=(i in output_dims),
|
||||
is_priority=(i in priority_dims),
|
||||
non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else
|
||||
None))
|
||||
non_recurrent_fn=non_recurrent_fn
|
||||
if i in non_recurrent_dims else None))
|
||||
return _GridRNNConfig(
|
||||
num_dims=num_dims,
|
||||
dims=rnn_dims,
|
||||
@ -544,8 +628,8 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
|
||||
cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
|
||||
m_prev[0].dtype)
|
||||
|
||||
last_dim_output = (new_output[-1] if new_output[-1] is not None
|
||||
else m_prev[-1])
|
||||
last_dim_output = (new_output[-1]
|
||||
if new_output[-1] is not None else m_prev[-1])
|
||||
|
||||
for i in dim_indices:
|
||||
d = conf.dims[i]
|
||||
@ -563,8 +647,8 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
|
||||
linear_args,
|
||||
num_outputs=conf.num_units,
|
||||
activation_fn=d.non_recurrent_fn,
|
||||
weights_initializer=vs.get_variable_scope().initializer or
|
||||
layers.initializers.xavier_initializer,
|
||||
weights_initializer=(vs.get_variable_scope().initializer or
|
||||
layers.initializers.xavier_initializer),
|
||||
weights_regularizer=vs.get_variable_scope().regularizer)
|
||||
else:
|
||||
if c_prev[i] is not None:
|
||||
|
@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
using functor::FillProjectiveTransform;
|
||||
using generator::INTERPOLATION_BILINEAR;
|
||||
using generator::INTERPOLATION_NEAREST;
|
||||
using generator::Interpolation;
|
||||
using generator::ProjectiveGenerator;
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ImageProjectiveTransform : public OpKernel {
|
||||
private:
|
||||
Interpolation interpolation_;
|
||||
|
||||
public:
|
||||
explicit ImageProjectiveTransform(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
string interpolation_str;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
|
||||
if (interpolation_str == "NEAREST") {
|
||||
interpolation_ = INTERPOLATION_NEAREST;
|
||||
} else if (interpolation_str == "BILINEAR") {
|
||||
interpolation_ = INTERPOLATION_BILINEAR;
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid interpolation " << interpolation_str
|
||||
<< ". Supported types: NEAREST, BILINEAR";
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& images_t = ctx->input(0);
|
||||
@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
|
||||
Tensor* output_t;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
|
||||
auto output = output_t->tensor<T, 4>();
|
||||
const FillProjectiveTransform<Device, T> functor;
|
||||
functor(ctx->eigen_device<Device>(), &output, images, transform);
|
||||
(FillProjectiveTransform<Device, T>(interpolation_))(
|
||||
ctx->eigen_device<Device>(), &output, images, transform);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -28,6 +28,8 @@ namespace tensorflow {
|
||||
|
||||
namespace generator {
|
||||
|
||||
enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
|
||||
|
||||
using Eigen::array;
|
||||
using Eigen::DenseIndex;
|
||||
|
||||
@ -36,20 +38,19 @@ class ProjectiveGenerator {
|
||||
private:
|
||||
typename TTypes<T, 4>::ConstTensor input_;
|
||||
typename TTypes<float>::ConstMatrix transforms_;
|
||||
const Interpolation interpolation_;
|
||||
|
||||
public:
|
||||
static const int kNumParameters = 8;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
|
||||
typename TTypes<float>::ConstMatrix transforms)
|
||||
: input_(input), transforms_(transforms) {}
|
||||
typename TTypes<float>::ConstMatrix transforms,
|
||||
const Interpolation interpolation)
|
||||
: input_(input), transforms_(transforms), interpolation_(interpolation) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
operator()(const array<DenseIndex, 4>& coords) const {
|
||||
array<DenseIndex, 4> input_coords;
|
||||
input_coords[0] = coords[0];
|
||||
|
||||
const int64 output_y = coords[1];
|
||||
const int64 output_x = coords[2];
|
||||
const float* transform =
|
||||
@ -57,24 +58,73 @@ class ProjectiveGenerator {
|
||||
? transforms_.data()
|
||||
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
|
||||
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
|
||||
const int64 input_x = std::round(
|
||||
const float input_x =
|
||||
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
|
||||
projection);
|
||||
const int64 input_y = std::round(
|
||||
projection;
|
||||
const float input_y =
|
||||
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
|
||||
projection);
|
||||
projection;
|
||||
|
||||
if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x &&
|
||||
input_x < input_.dimension(2))) {
|
||||
// TODO(ringwalt): Add a fill value input.
|
||||
return T(0);
|
||||
static const T fill_value = T(0);
|
||||
switch (interpolation_) {
|
||||
case INTERPOLATION_NEAREST:
|
||||
// Switch the order of x and y again for indexing into the image.
|
||||
return nearest_interpolation(coords[0], input_y, input_x, coords[3],
|
||||
fill_value);
|
||||
case INTERPOLATION_BILINEAR:
|
||||
return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
|
||||
fill_value);
|
||||
}
|
||||
}
|
||||
input_coords[1] = input_y;
|
||||
input_coords[2] = input_x;
|
||||
|
||||
input_coords[3] = coords[3];
|
||||
private:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
nearest_interpolation(const DenseIndex batch, const float y, const float x,
|
||||
const DenseIndex channel, const T fill_value) const {
|
||||
return read_with_fill_value(batch, DenseIndex(std::round(y)),
|
||||
DenseIndex(std::round(x)), channel, fill_value);
|
||||
}
|
||||
|
||||
return input_(input_coords);
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
bilinear_interpolation(const DenseIndex batch, const float y, const float x,
|
||||
const DenseIndex channel, const T fill_value) const {
|
||||
const float y_floor = std::floor(y);
|
||||
const float x_floor = std::floor(x);
|
||||
const float y_ceil = y_floor + 1;
|
||||
const float x_ceil = x_floor + 1;
|
||||
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
|
||||
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
|
||||
const float value_yfloor =
|
||||
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
|
||||
DenseIndex(x_floor), channel,
|
||||
fill_value) +
|
||||
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
|
||||
DenseIndex(x_ceil), channel,
|
||||
fill_value);
|
||||
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
|
||||
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
|
||||
const float value_yceil =
|
||||
(x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
|
||||
DenseIndex(x_floor), channel,
|
||||
fill_value) +
|
||||
(x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
|
||||
DenseIndex(x_ceil), channel,
|
||||
fill_value);
|
||||
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
|
||||
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
|
||||
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
|
||||
const DenseIndex batch, const DenseIndex y, const DenseIndex x,
|
||||
const DenseIndex channel, const T fill_value) const {
|
||||
// batch and channel must be correct, because they are passed unchanged from
|
||||
// the input.
|
||||
return (0 <= y && y < input_.dimension(1) && 0 <= x &&
|
||||
x < input_.dimension(2))
|
||||
? input_(array<DenseIndex, 4>{batch, y, x, channel})
|
||||
: fill_value;
|
||||
}
|
||||
};
|
||||
|
||||
@ -85,6 +135,7 @@ class ProjectiveGenerator {
|
||||
// some Eigen device code.
|
||||
namespace functor {
|
||||
|
||||
using generator::Interpolation;
|
||||
using generator::ProjectiveGenerator;
|
||||
|
||||
template <typename Device, typename T>
|
||||
@ -92,15 +143,17 @@ struct FillProjectiveTransform {
|
||||
typedef typename TTypes<T, 4>::Tensor OutputType;
|
||||
typedef typename TTypes<T, 4>::ConstTensor InputType;
|
||||
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
|
||||
const Interpolation interpolation_;
|
||||
|
||||
FillProjectiveTransform() {}
|
||||
FillProjectiveTransform(Interpolation interpolation)
|
||||
: interpolation_(interpolation) {}
|
||||
|
||||
EIGEN_ALWAYS_INLINE
|
||||
void operator()(const Device& device, OutputType* output,
|
||||
const InputType& images,
|
||||
const TransformsType& transform) const {
|
||||
ProjectiveGenerator<Device, T> generator(images, transform);
|
||||
output->device(device) = images.generate(generator);
|
||||
output->device(device) = images.generate(
|
||||
ProjectiveGenerator<Device, T>(images, transform, interpolation_));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -23,13 +23,13 @@ using shape_inference::InferenceContext;
|
||||
|
||||
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
|
||||
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
|
||||
// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc.
|
||||
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
|
||||
// implement "same" and "valid" modes in the Python function.
|
||||
REGISTER_OP("ImageProjectiveTransform")
|
||||
.Input("images: dtype")
|
||||
.Input("transforms: float32")
|
||||
.Attr("dtype: {uint8, int32, int64, float32, float64}")
|
||||
.Attr("interpolation: string")
|
||||
.Output("transformed_images: dtype")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->input(0));
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -111,6 +112,79 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
|
||||
[0, 1, 0, 1],
|
||||
[0, 1, 1, 1]])
|
||||
|
||||
def test_bilinear(self):
|
||||
with self.test_session():
|
||||
image = constant_op.constant(
|
||||
[[0, 0, 0, 0, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 1, 0, 1, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
dtypes.float32)
|
||||
# The following result matches:
|
||||
# >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
|
||||
# which uses spline interpolation of order 1, equivalent to bilinear
|
||||
# interpolation.
|
||||
self.assertAllClose(
|
||||
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
|
||||
[[0.000, 0.000, 0.343, 0.000, 0.000],
|
||||
[0.000, 0.586, 0.914, 0.586, 0.000],
|
||||
[0.343, 0.914, 0.000, 0.914, 0.343],
|
||||
[0.000, 0.586, 0.914, 0.586, 0.000],
|
||||
[0.000, 0.000, 0.343, 0.000, 0.000]],
|
||||
atol=0.001)
|
||||
self.assertAllClose(
|
||||
image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(),
|
||||
[[0, 0, 1, 0, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[1, 1, 0, 1, 1],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 1, 0, 0]])
|
||||
|
||||
def test_bilinear_uint8(self):
|
||||
with self.test_session():
|
||||
image = constant_op.constant(
|
||||
np.asarray(
|
||||
[[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 255, 255, 255, 0.0],
|
||||
[0.0, 255, 0.0, 255, 0.0],
|
||||
[0.0, 255, 255, 255, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0]],
|
||||
np.uint8),
|
||||
dtypes.uint8)
|
||||
# == np.rint((expected image above) * 255)
|
||||
self.assertAllEqual(
|
||||
image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
|
||||
[[0.0, 0.0, 87., 0.0, 0.0],
|
||||
[0.0, 149, 233, 149, 0.0],
|
||||
[87., 233, 0.0, 233, 87.],
|
||||
[0.0, 149, 233, 149, 0.0],
|
||||
[0.0, 0.0, 87., 0.0, 0.0]])
|
||||
|
||||
def _test_grad(self, shape_to_test):
|
||||
with self.test_session():
|
||||
test_image_shape = shape_to_test
|
||||
test_image = np.random.randn(*test_image_shape)
|
||||
test_image_tensor = constant_op.constant(
|
||||
test_image, shape=test_image_shape)
|
||||
test_transform = image_ops.angles_to_projective_transforms(
|
||||
np.pi / 2, 4, 4)
|
||||
|
||||
output_shape = test_image_shape
|
||||
output = image_ops.transform(test_image_tensor, test_transform)
|
||||
left_err = gradient_checker.compute_gradient_error(
|
||||
test_image_tensor,
|
||||
test_image_shape,
|
||||
output,
|
||||
output_shape,
|
||||
x_init_value=test_image)
|
||||
self.assertLess(left_err, 1e-10)
|
||||
|
||||
def test_grad(self):
|
||||
self._test_grad([16, 16])
|
||||
self._test_grad([4, 12, 12])
|
||||
self._test_grad([3, 4, 12, 12])
|
||||
|
||||
|
||||
def _test_grad(self, shape_to_test):
|
||||
with self.test_session():
|
||||
|
@ -24,8 +24,8 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_image_ops_so = loader.load_op_library(
|
||||
@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
|
||||
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
|
||||
|
||||
|
||||
def rotate(images, angles):
|
||||
def rotate(images, angles, interpolation="NEAREST"):
|
||||
"""Rotate image(s) by the passed angle(s) in radians.
|
||||
|
||||
Args:
|
||||
@ -46,6 +46,7 @@ def rotate(images, angles):
|
||||
(num_rows, num_columns) (HW).
|
||||
angles: A scalar angle to rotate all images by, or (if images has rank 4)
|
||||
a vector of length num_images, with an angle for each image in the batch.
|
||||
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||
|
||||
Returns:
|
||||
Image(s) with the same type and shape as `images`, rotated by the given
|
||||
@ -70,7 +71,8 @@ def rotate(images, angles):
|
||||
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
|
||||
output = transform(
|
||||
images,
|
||||
angles_to_projective_transforms(angles, image_width, image_height))
|
||||
angles_to_projective_transforms(angles, image_height, image_width),
|
||||
interpolation=interpolation)
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return output[0, :, :, 0]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
|
||||
axis=1)
|
||||
|
||||
|
||||
def transform(images, transforms):
|
||||
def transform(images, transforms, interpolation="NEAREST"):
|
||||
"""Applies the given transform(s) to the image(s).
|
||||
|
||||
Args:
|
||||
@ -134,6 +136,7 @@ def transform(images, transforms):
|
||||
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
|
||||
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
|
||||
the transform mapping input points to output points.
|
||||
interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
|
||||
|
||||
Returns:
|
||||
Image(s) with the same type and shape as `images`, with the given
|
||||
@ -163,8 +166,8 @@ def transform(images, transforms):
|
||||
transforms = transform_or_transforms
|
||||
else:
|
||||
raise TypeError("Transforms should have rank 1 or 2.")
|
||||
# pylint: disable=protected-access
|
||||
output = gen_image_ops.image_projective_transform(images, transforms)
|
||||
output = gen_image_ops.image_projective_transform(
|
||||
images, transforms, interpolation=interpolation.upper())
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return output[0, :, :, 0]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
@ -217,8 +220,10 @@ def _transform_matrices_to_flat(transform_matrices):
|
||||
|
||||
@ops.RegisterGradient("ImageProjectiveTransform")
|
||||
def _image_projective_transform_grad(op, grad):
|
||||
"""Computes the gradient for ImageProjectiveTransform."""
|
||||
images = op.inputs[0]
|
||||
transforms = op.inputs[1]
|
||||
interpolation = op.get_attr("interpolation")
|
||||
|
||||
image_or_images = ops.convert_to_tensor(images, name="images")
|
||||
transform_or_transforms = ops.convert_to_tensor(
|
||||
@ -245,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
|
||||
transforms = _flat_transforms_to_matrices(transforms=transforms)
|
||||
inverse = linalg_ops.matrix_inverse(transforms)
|
||||
transforms = _transform_matrices_to_flat(inverse)
|
||||
output = gen_image_ops.image_projective_transform(grad, transforms)
|
||||
output = gen_image_ops.image_projective_transform(
|
||||
grad, transforms, interpolation=interpolation)
|
||||
if len(image_or_images.get_shape()) == 2:
|
||||
return [output[0, :, :, 0], None]
|
||||
elif len(image_or_images.get_shape()) == 3:
|
||||
|
55
tensorflow/contrib/kernel_methods/README.md
Normal file
55
tensorflow/contrib/kernel_methods/README.md
Normal file
@ -0,0 +1,55 @@
|
||||
# TensorFlow contrib kernel_methods.
|
||||
|
||||
This module contains operations and estimators that enable the use of primal
|
||||
(explicit) kernel methods in TensorFlow. See also the [tutorial](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/g3doc/tutorial.md) on how to use this module to improve the quality of
|
||||
classification or regression tasks.
|
||||
|
||||
## Kernel Mappers
|
||||
Implement explicit kernel mapping Ops over tensors. Kernel mappers add
|
||||
Tensor-In-Tensor-Out (TITO) Ops to the TensorFlow graph. They can be used in
|
||||
conjunction with other layers or ML models.
|
||||
|
||||
Sample usage:
|
||||
|
||||
```python
|
||||
kernel_mapper = tf.contrib.kernel_methods.SomeKernelMapper(...)
|
||||
out_tensor = kernel_mapper.map(in_tensor)
|
||||
... # code that consumes out_tensor.
|
||||
```
|
||||
|
||||
Currently, there is a [RandomFourierFeatureMapper]
|
||||
(https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py) implemented that maps dense
|
||||
input to dense output.
|
||||
|
||||
## Kernel-based Estimators
|
||||
tf.contrib.learn Estimators that use kernel mappers internally to discover
|
||||
non-linearities in the data. These canned estimators map their input features
|
||||
using kernel mapper Ops and then apply linear models to the mapped
|
||||
features. Combining kernel mappers with linear models and different loss
|
||||
functions leads to a variety of models: linear and non-linear SVMs, linear
|
||||
regression (with and without kernels) and (multinomial) logistic regression
|
||||
(with and without kernels).
|
||||
|
||||
Currently there is a [KernelLinearClassifier]
|
||||
(https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/kernel_estimators.py) implemented but more pre-packaged estimators
|
||||
are on the way.
|
||||
|
||||
Sample usage:
|
||||
|
||||
```python
|
||||
real_column_a = tf.contrib.layers.real_valued_column(name='real_column_a',...)
|
||||
sparse_column_b = tf.contrib.layers.sparse_column_with_hash_bucket(...)
|
||||
kernel_mappers = {real_column_a : [tf.contrib.kernel_methods.SomeKernelMapper(...)]}
|
||||
optimizer = ...
|
||||
|
||||
kernel_classifier = tf.contrib.kernel_methods.KernelLinearClassifier(
|
||||
feature_columns=[real_column_a, sparse_column_b],
|
||||
model_dir=...,
|
||||
optimizer=optimizer,
|
||||
kernel_mappers=kernel_mappers)
|
||||
|
||||
# Construct input_fns
|
||||
kernel_classifier.fit(...)
|
||||
kernel_classifier.evaluate(...)
|
||||
```
|
||||
|
BIN
tensorflow/contrib/kernel_methods/g3doc/acc-vs-trn_time.png
Normal file
BIN
tensorflow/contrib/kernel_methods/g3doc/acc-vs-trn_time.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 18 KiB |
BIN
tensorflow/contrib/kernel_methods/g3doc/acc_vs_outdim.png
Normal file
BIN
tensorflow/contrib/kernel_methods/g3doc/acc_vs_outdim.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 19 KiB |
BIN
tensorflow/contrib/kernel_methods/g3doc/kernel_mapping.jpg
Normal file
BIN
tensorflow/contrib/kernel_methods/g3doc/kernel_mapping.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.2 KiB |
273
tensorflow/contrib/kernel_methods/g3doc/tutorial.md
Normal file
273
tensorflow/contrib/kernel_methods/g3doc/tutorial.md
Normal file
@ -0,0 +1,273 @@
|
||||
# Improving classification using explicit kernel methods
|
||||
|
||||
In this tutorial, we demonstrate how combining (explicit) kernel methods with
|
||||
linear models can drastically increase the latters' quality of predictions
|
||||
without significantly increasing training and inference times. Currently,
|
||||
explicit kernel mappings are supported for dense features. Support for sparse
|
||||
features is in the works.
|
||||
|
||||
We will use [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn) (TensorFlow's high-level Machine Learning API) Estimators for our ML models.
|
||||
tf.contrib.learn API reduces the boilerplate code one needs to write for
|
||||
configuring, training and evaluating models and will let us focus on the core
|
||||
ideas. If you are not familiar with this API, [tf.contrib.learn Quickstart](https://www.tensorflow.org/get_started/tflearn) is a good place to start. We
|
||||
will use MNIST, a widely-used dataset containing images of handwritten digits
|
||||
(between 0 and 9). The tutorial consists of the following steps:
|
||||
|
||||
* Load and prepare MNIST data for classification.
|
||||
* Construct a simple linear model, train it and evaluate it on the eval data.
|
||||
* Replace the linear model with a kernelized linear model, re-train and
|
||||
re-evaluate.
|
||||
|
||||
## Load and prepare MNIST data for classification
|
||||
The first step is to prepare the data to be fed to the ML models. The following
|
||||
utility command from tf.contrib.learn loads the MNIST dataset:
|
||||
|
||||
```python
|
||||
data = tf.contrib.learn.datasets.mnist.load_mnist()
|
||||
```
|
||||
This loads the entire MNIST dataset (containing 70K samples) and splits it into
|
||||
train, validation and test data with 55K, 5K and 10K samples respectively. Each
|
||||
split contains one numpy array for images (with shape [sample_size, 784]) and
|
||||
one for labels (with shape [sample_size, 1]). In this tutorial, we only use the
|
||||
train and validation splits (to train and evaluate our models respectively).
|
||||
|
||||
In order to feed data to a tf.contrib.learn Estimator, it is helpful to convert
|
||||
it to Tensors. For this, we will use an `input function` which adds Ops to the
|
||||
TensorFlow graph that, when executed, create mini-batches of Tensors to be used
|
||||
downstream. For more background on input functions, check
|
||||
[Building Input Functions with tf.contrib.learn](https://www.tensorflow.org/get_started/input_fn). In this example, we will use the `tf.train.shuffle_batch` Op which,
|
||||
besides converting numpy arrays to Tensors, allows us to specify the batch_size
|
||||
and whether to randomize the input every time the input_fn Ops are executed
|
||||
(randomization typically expedites convergence during training). The full code
|
||||
for loading and preparing the data is shown in the snippet below. In this
|
||||
example, we use mini-batches of size 256 for training and the entire sample (5K
|
||||
entries) for evaluation. Feel free to experiment with different batch sizes.
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):
|
||||
|
||||
def _input_fn():
|
||||
images_batch, labels_batch = tf.train.shuffle_batch(
|
||||
tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
|
||||
batch_size=batch_size,
|
||||
capacity=capacity,
|
||||
min_after_dequeue=min_after_dequeue,
|
||||
enqueue_many=True,
|
||||
num_threads=4)
|
||||
features_map = {'images': images_batch}
|
||||
return features_map, labels_batch
|
||||
|
||||
return _input_fn
|
||||
|
||||
data = tf.contrib.learn.datasets.mnist.load_mnist()
|
||||
|
||||
train_input_fn = get_input_fn(data.train, batch_size=256)
|
||||
eval_input_fn = get_input_fn(data.validation, batch_size=5000)
|
||||
|
||||
```
|
||||
|
||||
## Training a simple linear model
|
||||
We can now train a linear model over the MNIST dataset. We will use the [tf.contrib.learn.LinearClassifier](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py) estimator with 10 classes (representing the 10
|
||||
digits). The input features form a 784-dimensional (dense) vector which can be
|
||||
specified as follows:
|
||||
|
||||
```python
|
||||
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
|
||||
```
|
||||
|
||||
The full code for constructing, training and evaluating a LinearClassifier
|
||||
estimator is shown below.
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
# Specify the feature(s) to be used by the estimator.
|
||||
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
|
||||
estimator = tf.contrib.learn.LinearClassifier(feature_columns=[image_column], n_classes=10)
|
||||
|
||||
# Train.
|
||||
start = time.time()
|
||||
estimator.fit(input_fn=train_input_fn, steps=2000)
|
||||
end = time.time()
|
||||
print('Elapsed time: {} seconds'.format(end - start))
|
||||
|
||||
# Evaluate and report metrics.
|
||||
eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
|
||||
print(eval_metrics)
|
||||
```
|
||||
On eval data, the loss (i.e., the value of the objective function being
|
||||
minimized during training) lies between **0.25 and 0.30** (depending on the
|
||||
parameters used) while the accuracy of the classifier is approximately **92.5%**
|
||||
(training is randomized so the exact loss and accuracy will vary). Also, the
|
||||
training time is around 25 seconds (this will also vary based on the machine you
|
||||
run the code on).
|
||||
|
||||
In addition to experimenting with the (training) batch size and the number of
|
||||
training steps, there are a couple other parameters that can be tuned as well.
|
||||
For instance, you can change the optimization method used to minimize the loss
|
||||
by explicitly selecting another optimizer from the collection of
|
||||
[available optimizers]
|
||||
(https://www.tensorflow.org/code/tensorflow/python/training).
|
||||
As an example, the following code constructs a LinearClassifer estimator that
|
||||
uses the Follow-The-Regularized-Leader (FTRL) optimization strategy with a
|
||||
specific learning rate and l2-regularization.
|
||||
|
||||
|
||||
```python
|
||||
optimizer = tf.train.FtrlOptimizer(learning_rate=5.0, l2_regularization_strength=1.0)
|
||||
estimator = tf.contrib.learn.LinearClassifier(
|
||||
feature_columns=[image_column], n_classes=10, optimizer=optimizer)
|
||||
```
|
||||
|
||||
Regardless of the values of the parameters, the max accuracy a linear model can
|
||||
achieve on this dataset caps at around **93%**.
|
||||
|
||||
## Using explicit kernel mappings with the linear model.
|
||||
The relatively high error (~7%) of the linear model over MNIST indicates that
|
||||
the input data is not linearly separable. We will use explicit kernel mappings
|
||||
to reduce the classification error.
|
||||
|
||||
**Intuition:** The high-level idea is to use a non-linear map to transform the
|
||||
input space to another feature space (of possibly higher dimension) where the
|
||||
(transformed) features are (almost) linearly separable and then apply a linear
|
||||
model on the mapped features. This is shown in the following figure:
|
||||
|
||||
<div style="text-align:center">
|
||||
<img src="./kernel_mapping.png">
|
||||
</div>
|
||||
|
||||
**Technical details overview:** In this example we will use **Random Fourier
|
||||
Features** (introduced in the ["Random Features for Large-Scale Kernel Machines"]
|
||||
(https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) paper by
|
||||
Rahimi and Recht) to map the input data. Random Fourier Features map a vector
|
||||
$$\mathbf{x} \in \mathbb{R}^d$$ to $$\mathbf{x'} \in \mathbb{R}^D$$ via the
|
||||
following mapping:
|
||||
|
||||
$$
|
||||
RFFM(\cdot): \mathbb{R}^d \to \mathbb{R}^D, \quad
|
||||
RFFM(\mathbf{x}) = \cos(\mathbf{\Omega} \cdot \mathbf{x}+ \mathbf{b})
|
||||
$$
|
||||
|
||||
where $$\mathbf{\Omega} \in \mathbb{R}^{D \times d}$$,
|
||||
$$\mathbf{x} \in \mathbb{R}^d,$$ $$\mathbf{b} \in \mathbb{R}^D$$ and cosine is
|
||||
applied elementwise.
|
||||
|
||||
In this example, the entries of $$\mathbf{\Omega}$$ and $$\mathbf{b}$$ are
|
||||
sampled from distributions such that the mapping satisfies the following
|
||||
property:
|
||||
|
||||
$$
|
||||
RFFM(\mathbf{x})^T \cdot RFFM(\mathbf{y}) \approx
|
||||
e^{-\frac{\|\mathbf{x} - \mathbf{y}\|^2}{2 \sigma^2}}
|
||||
$$
|
||||
|
||||
The right-hand-side quantity of the expression above is known as the RBF (or
|
||||
Gaussian) kernel function. This function is one of the most-widely used kernel
|
||||
functions in Machine Learning and measures (implicitly) similarity in a
|
||||
different (much higher dimensional) space than the original one. See
|
||||
[Radial basis function kernel](https://en.wikipedia.org/wiki/Radial_basis_function_kernel)
|
||||
for more details.
|
||||
|
||||
**Kernel Classifier:** `tf.contrib.kernel_methods.KernelLinearClassifier` is a
|
||||
pre-packaged `tf.contrib.learn` estimator that combines the power of explicit
|
||||
kernel mappings with linear models. Its API is very similar to that of the
|
||||
LinearClassifier with the additional ability to specify a list of explicit
|
||||
kernel mappings to be apply to each feature used by the classifier. The
|
||||
following code snippet demonstrates how to replace LinearClassifier with
|
||||
KernelLinearClassifier.
|
||||
|
||||
|
||||
```python
|
||||
# Specify the feature(s) to be used by the estimator. This is identical to the
|
||||
# code used for the LinearClassifier.
|
||||
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
|
||||
optimizer = tf.train.FtrlOptimizer(
|
||||
learning_rate=50.0, l2_regularization_strength=0.001)
|
||||
|
||||
|
||||
kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
|
||||
input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
|
||||
kernel_mappers = {image_column: [kernel_mapper]}
|
||||
estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
|
||||
n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
|
||||
|
||||
# Train.
|
||||
start = time.time()
|
||||
estimator.fit(input_fn=train_input_fn, steps=2000)
|
||||
end = time.time()
|
||||
print('Elapsed time: {} seconds'.format(end - start))
|
||||
|
||||
# Evaluate and report metrics.
|
||||
eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
|
||||
print(eval_metrics)
|
||||
```
|
||||
The only additional parameter passed to `KernelLinearClassifier` is a dictionary
|
||||
from feature_columns to a list of kernel mappings to be applied to the
|
||||
corresponding feature column. In this example, the lines
|
||||
|
||||
```python
|
||||
kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
|
||||
input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
|
||||
kernel_mappers = {image_column: [kernel_mapper]}
|
||||
estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
|
||||
n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
|
||||
```
|
||||
instruct the classifier to first map the initial 784-dimensional images to
|
||||
2000-dimensional vectors using random Fourier features and then learn a linear
|
||||
model on the transformed vectors. Note that, besides the output dimension, there
|
||||
is one more parameter (stddev) involved. This parameter is the standard
|
||||
deviation ($$\sigma$$) of the approximated RBF kernel and controls the
|
||||
similarity measure used in classification. This parameter is typically
|
||||
determined via hyperparameter tuning.
|
||||
|
||||
Running the code above yields a loss of approximately **0.10** while the
|
||||
accuracy is increased to approximately **97%** on eval data (an increase of 4%
|
||||
over the plain linear model). The training time hovers around 35 seconds. We can
|
||||
increase the accuracy even more, by increasing the output dimension of the
|
||||
mapping and tuning the standard deviation even more.
|
||||
|
||||
**On the role of stddev:** The classification quality is very sensitive to the
|
||||
value of the stddev parameter used to define the similarity measure between the
|
||||
pairs of input features. The following table shows the accuracy of the
|
||||
classifier on the eval data for different values of stddev (for all experiments
|
||||
the output dimension was fixed to 3000). The optimal value is stddev=5.0. Notice
|
||||
how too small or too high stddev values can dramatically decrease the accuracy
|
||||
of the classification.
|
||||
|
||||
stddev | eval accuracy
|
||||
:----- | :------------
|
||||
1.0 | 0.1362
|
||||
2.0 | 0.4764
|
||||
4.0 | 0.9654
|
||||
5.0 | 0.9766
|
||||
8.0 | 0.9714
|
||||
16.0 | 0.8878
|
||||
|
||||
**On the role of the output dimension:** Intuitively, the larger the output
|
||||
dimension of the mapping, the closer the inner product of two mapped vectors
|
||||
approximates the kernel which typically translates to better classification
|
||||
accuracy. Another way to think about this is that the output dimension equals
|
||||
the number of weights of the linear model (the larger this dimension, the larger
|
||||
the "degrees of freedom" of the model). However, after a certain threshold,
|
||||
higher output dimensions increase the accuracy by very little (while still
|
||||
increasing the training time). This is shown in the following 2 Figures which
|
||||
depict the eval accuracy as a function of the output dimension and the training
|
||||
time respectively.
|
||||
|
||||
 
|
||||
|
||||
|
||||
## Explicit Kernel Mappings: summary and practical tips
|
||||
* Explicit kernel mappings combine the predictive power of non-linear models
|
||||
with the scalability of linear models.
|
||||
* Random Fourier Features can be particularly effective for datasets with dense
|
||||
features.
|
||||
* The parameters of the kernel mapping are often data-dependent. Model quality
|
||||
can be very sensitive to these parameters. Use hyperparameter tuning to find the
|
||||
optimal values.
|
||||
* If you have multiple numerical features, concatinate them into a single
|
||||
multi-dimensional one and apply the kernel mapping to the concatenated vector.
|
||||
|
@ -131,21 +131,27 @@ import math
|
||||
import six
|
||||
|
||||
from tensorflow.contrib import lookup
|
||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||
from tensorflow.contrib.framework.python.framework import experimental
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.layers.python.layers import embedding_ops
|
||||
from tensorflow.contrib.layers.python.layers import layers
|
||||
from tensorflow.contrib.layers.python.ops import bucketization_op
|
||||
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
|
||||
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
|
||||
from tensorflow.python.feature_column import feature_column as fc_core
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation
|
||||
|
||||
@ -291,11 +297,13 @@ class _FeatureColumn(object):
|
||||
|
||||
|
||||
# TODO(b/30410315): Support warm starting in all feature columns.
|
||||
class _SparseColumn(_FeatureColumn,
|
||||
collections.namedtuple("_SparseColumn",
|
||||
["column_name", "is_integerized",
|
||||
"bucket_size", "lookup_config",
|
||||
"combiner", "dtype"])):
|
||||
class _SparseColumn(
|
||||
_FeatureColumn,
|
||||
fc_core._CategoricalColumn, # pylint: disable=protected-access
|
||||
collections.namedtuple("_SparseColumn", [
|
||||
"column_name", "is_integerized", "bucket_size", "lookup_config",
|
||||
"combiner", "dtype"
|
||||
])):
|
||||
"""Represents a sparse feature column also known as categorical features.
|
||||
|
||||
Instances of this class are immutable. A sparse column means features are
|
||||
@ -426,9 +434,8 @@ class _SparseColumn(_FeatureColumn,
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
combiner=self.combiner)
|
||||
|
||||
def _get_input_sparse_tensor(self, columns_to_tensors):
|
||||
"""Looks up the input tensor for transformation and sparsify it if dense."""
|
||||
input_tensor = columns_to_tensors[self.name]
|
||||
def _get_input_sparse_tensor(self, input_tensor):
|
||||
"""sparsify input_tensor if dense."""
|
||||
if not isinstance(input_tensor, sparse_tensor_py.SparseTensor):
|
||||
# To avoid making any assumptions about which values are to be ignored,
|
||||
# we set ignore_value to -1 for numeric tensors to avoid excluding valid
|
||||
@ -455,18 +462,44 @@ class _SparseColumn(_FeatureColumn,
|
||||
format(self.name, other_column.name))
|
||||
return compatible
|
||||
|
||||
@abc.abstractmethod
|
||||
def _do_transform(self, input_tensor):
|
||||
pass
|
||||
|
||||
def insert_transformed_feature(self, columns_to_tensors):
|
||||
"""Handles sparse column to id conversion."""
|
||||
input_tensor = self._get_input_sparse_tensor(columns_to_tensors[self.name])
|
||||
columns_to_tensors[self] = self._do_transform(input_tensor)
|
||||
|
||||
def _transform_feature(self, inputs):
|
||||
input_tensor = self._get_input_sparse_tensor(inputs.get(self.name))
|
||||
return self._do_transform(input_tensor)
|
||||
|
||||
@property
|
||||
def _parse_example_config(self):
|
||||
return self.config
|
||||
|
||||
@property
|
||||
def _num_buckets(self):
|
||||
return self.length
|
||||
|
||||
def _get_sparse_tensors(self, inputs, weight_collections=None,
|
||||
trainable=None):
|
||||
del weight_collections
|
||||
del trainable
|
||||
input_tensor = inputs.get(self)
|
||||
return fc_core._CategoricalColumn.IdWeightPair( # pylint: disable=protected-access
|
||||
self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
|
||||
|
||||
|
||||
class _SparseColumnIntegerized(_SparseColumn):
|
||||
"""See `sparse_column_with_integerized_feature`."""
|
||||
|
||||
def insert_transformed_feature(self, columns_to_tensors):
|
||||
"""Handles sparse column to id conversion."""
|
||||
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
|
||||
|
||||
def _do_transform(self, input_tensor):
|
||||
sparse_id_values = math_ops.mod(input_tensor.values, self.bucket_size,
|
||||
name="mod")
|
||||
columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
|
||||
input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
|
||||
return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values,
|
||||
input_tensor.dense_shape)
|
||||
|
||||
|
||||
def sparse_column_with_integerized_feature(column_name,
|
||||
@ -517,10 +550,7 @@ def sparse_column_with_integerized_feature(column_name,
|
||||
class _SparseColumnHashed(_SparseColumn):
|
||||
"""See `sparse_column_with_hash_bucket`."""
|
||||
|
||||
def insert_transformed_feature(self, columns_to_tensors):
|
||||
"""Handles sparse column to id conversion."""
|
||||
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
|
||||
|
||||
def _do_transform(self, input_tensor):
|
||||
if self.dtype.is_integer:
|
||||
sparse_values = string_ops.as_string(input_tensor.values)
|
||||
else:
|
||||
@ -528,8 +558,8 @@ class _SparseColumnHashed(_SparseColumn):
|
||||
|
||||
sparse_id_values = string_ops.string_to_hash_bucket_fast(
|
||||
sparse_values, self.bucket_size, name="lookup")
|
||||
columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
|
||||
input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
|
||||
return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values,
|
||||
input_tensor.dense_shape)
|
||||
|
||||
|
||||
def sparse_column_with_hash_bucket(column_name,
|
||||
@ -572,16 +602,13 @@ def sparse_column_with_hash_bucket(column_name,
|
||||
class _SparseColumnKeys(_SparseColumn):
|
||||
"""See `sparse_column_with_keys`."""
|
||||
|
||||
def insert_transformed_feature(self, columns_to_tensors):
|
||||
"""Handles sparse column to id conversion."""
|
||||
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
|
||||
|
||||
def _do_transform(self, input_tensor):
|
||||
table = lookup.index_table_from_tensor(
|
||||
mapping=tuple(self.lookup_config.keys),
|
||||
default_value=self.lookup_config.default_value,
|
||||
dtype=self.dtype,
|
||||
name="lookup")
|
||||
columns_to_tensors[self] = table.lookup(input_tensor)
|
||||
return table.lookup(input_tensor)
|
||||
|
||||
|
||||
def sparse_column_with_keys(
|
||||
@ -621,9 +648,7 @@ def sparse_column_with_keys(
|
||||
class _SparseColumnVocabulary(_SparseColumn):
|
||||
"""See `sparse_column_with_vocabulary_file`."""
|
||||
|
||||
def insert_transformed_feature(self, columns_to_tensors):
|
||||
"""Handles sparse column to id conversion."""
|
||||
st = self._get_input_sparse_tensor(columns_to_tensors)
|
||||
def _do_transform(self, st):
|
||||
if self.dtype.is_integer:
|
||||
sparse_string_values = string_ops.as_string(st.values)
|
||||
sparse_string_tensor = sparse_tensor_py.SparseTensor(st.indices,
|
||||
@ -638,7 +663,7 @@ class _SparseColumnVocabulary(_SparseColumn):
|
||||
vocab_size=self.lookup_config.vocab_size,
|
||||
default_value=self.lookup_config.default_value,
|
||||
name=self.name + "_lookup")
|
||||
columns_to_tensors[self] = table.lookup(sparse_string_tensor)
|
||||
return table.lookup(sparse_string_tensor)
|
||||
|
||||
|
||||
def sparse_column_with_vocabulary_file(column_name,
|
||||
@ -694,9 +719,12 @@ def sparse_column_with_vocabulary_file(column_name,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
|
||||
"_WeightedSparseColumn",
|
||||
["sparse_id_column", "weight_column_name", "dtype"])):
|
||||
class _WeightedSparseColumn(
|
||||
_FeatureColumn,
|
||||
fc_core._CategoricalColumn, # pylint: disable=protected-access
|
||||
collections.namedtuple("_WeightedSparseColumn",
|
||||
["sparse_id_column", "weight_column_name",
|
||||
"dtype"])):
|
||||
"""See `weighted_sparse_column`."""
|
||||
|
||||
def __new__(cls, sparse_id_column, weight_column_name, dtype):
|
||||
@ -725,22 +753,6 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
|
||||
"""Returns a string which will be used as a key when we do sorting."""
|
||||
return "{}".format(self)
|
||||
|
||||
def insert_transformed_feature(self, columns_to_tensors):
|
||||
"""Inserts a tuple with the id and weight tensors."""
|
||||
if self.sparse_id_column not in columns_to_tensors:
|
||||
self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
|
||||
|
||||
weight_tensor = columns_to_tensors[self.weight_column_name]
|
||||
if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
|
||||
# The weight tensor can be a regular Tensor. In such case, sparsify it.
|
||||
weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
|
||||
if not self.dtype.is_floating:
|
||||
weight_tensor = math_ops.to_float(weight_tensor)
|
||||
columns_to_tensors[self] = tuple([
|
||||
columns_to_tensors[self.sparse_id_column],
|
||||
weight_tensor
|
||||
])
|
||||
|
||||
def id_tensor(self, input_tensor):
|
||||
"""Returns the id tensor from the given transformed input_tensor."""
|
||||
return input_tensor[0]
|
||||
@ -768,6 +780,43 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
combiner=self.sparse_id_column.combiner)
|
||||
|
||||
def _do_transform(self, id_tensor, weight_tensor):
|
||||
if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
|
||||
# The weight tensor can be a regular Tensor. In such case, sparsify it.
|
||||
weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
|
||||
if not self.dtype.is_floating:
|
||||
weight_tensor = math_ops.to_float(weight_tensor)
|
||||
return tuple([id_tensor, weight_tensor])
|
||||
|
||||
def insert_transformed_feature(self, columns_to_tensors):
|
||||
"""Inserts a tuple with the id and weight tensors."""
|
||||
if self.sparse_id_column not in columns_to_tensors:
|
||||
self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
|
||||
|
||||
weight_tensor = columns_to_tensors[self.weight_column_name]
|
||||
columns_to_tensors[self] = self._do_transform(
|
||||
columns_to_tensors[self.sparse_id_column], weight_tensor)
|
||||
|
||||
def _transform_feature(self, inputs):
|
||||
return self._do_transform(
|
||||
inputs.get(self.sparse_id_column), inputs.get(self.weight_column_name))
|
||||
|
||||
@property
|
||||
def _parse_example_config(self):
|
||||
return self.config
|
||||
|
||||
@property
|
||||
def _num_buckets(self):
|
||||
return self.length
|
||||
|
||||
def _get_sparse_tensors(self, inputs, weight_collections=None,
|
||||
trainable=None):
|
||||
del weight_collections
|
||||
del trainable
|
||||
input_tensor = inputs.get(self)
|
||||
return fc_core._CategoricalColumn.IdWeightPair( # pylint: disable=protected-access
|
||||
self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
|
||||
|
||||
|
||||
def weighted_sparse_column(sparse_id_column,
|
||||
weight_column_name,
|
||||
@ -815,9 +864,10 @@ def weighted_sparse_column(sparse_id_column,
|
||||
return _WeightedSparseColumn(sparse_id_column, weight_column_name, dtype)
|
||||
|
||||
|
||||
class _OneHotColumn(_FeatureColumn,
|
||||
collections.namedtuple("_OneHotColumn",
|
||||
["sparse_id_column"])):
|
||||
class _OneHotColumn(
|
||||
_FeatureColumn,
|
||||
fc_core._DenseColumn, # pylint: disable=protected-access
|
||||
collections.namedtuple("_OneHotColumn", ["sparse_id_column"])):
|
||||
"""Represents a one-hot column for use in deep networks.
|
||||
|
||||
Args:
|
||||
@ -897,12 +947,31 @@ class _OneHotColumn(_FeatureColumn,
|
||||
return math_ops.reduce_sum(
|
||||
one_hot_id_tensor, reduction_indices=[output_rank - 1])
|
||||
|
||||
@property
|
||||
def _variable_shape(self):
|
||||
return tensor_shape.TensorShape((self.length))
|
||||
|
||||
class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
||||
"_EmbeddingColumn",
|
||||
["sparse_id_column", "dimension", "combiner", "initializer",
|
||||
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||
del weight_collections
|
||||
del trainable
|
||||
return inputs.get(self)
|
||||
|
||||
def _transform_feature(self, inputs):
|
||||
return self._to_dnn_input_layer(inputs.get(self.sparse_id_column))
|
||||
|
||||
@property
|
||||
def _parse_example_config(self):
|
||||
return self.config
|
||||
|
||||
|
||||
class _EmbeddingColumn(
|
||||
_FeatureColumn,
|
||||
fc_core._DenseColumn, # pylint: disable=protected-access
|
||||
collections.namedtuple("_EmbeddingColumn", [
|
||||
"sparse_id_column", "dimension", "combiner", "initializer",
|
||||
"ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name",
|
||||
"shared_vocab_size", "max_norm", "trainable"])):
|
||||
"shared_vocab_size", "max_norm", "trainable"
|
||||
])):
|
||||
"""Represents an embedding column.
|
||||
|
||||
Args:
|
||||
@ -1027,6 +1096,139 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
|
||||
raise ValueError("Column {} is not supported in linear models. "
|
||||
"Please use sparse_column.".format(self))
|
||||
|
||||
@property
|
||||
def _variable_shape(self):
|
||||
return tensor_shape.TensorShape((self.dimension))
|
||||
|
||||
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
|
||||
return _embeddings_from_arguments(
|
||||
self,
|
||||
self._deep_embedding_lookup_arguments(inputs.get(self)),
|
||||
weight_collections, trainable)
|
||||
|
||||
def _transform_feature(self, inputs):
|
||||
return inputs.get(self.sparse_id_column)
|
||||
|
||||
@property
|
||||
def _parse_example_config(self):
|
||||
return self.config
|
||||
|
||||
|
||||
def _is_variable(v):
|
||||
"""Returns true if `v` is a variable."""
|
||||
return isinstance(v, (variables.Variable,
|
||||
resource_variable_ops.ResourceVariable))
|
||||
|
||||
|
||||
def _embeddings_from_arguments(column,
|
||||
args,
|
||||
weight_collections,
|
||||
trainable,
|
||||
output_rank=2):
|
||||
"""Returns embeddings for a column based on the computed arguments.
|
||||
|
||||
Args:
|
||||
column: the column name.
|
||||
args: the _DeepEmbeddingLookupArguments for this column.
|
||||
weight_collections: collections to store weights in.
|
||||
trainable: whether these embeddings should be trainable.
|
||||
output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
|
||||
be combined to produce the desired rank.
|
||||
|
||||
Returns:
|
||||
the embeddings.
|
||||
|
||||
Raises:
|
||||
ValueError: if not possible to create.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
|
||||
weight_tensor = None
|
||||
if args.weight_tensor is not None:
|
||||
weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# This option is only enabled for scattered_embedding_column.
|
||||
if args.hash_key:
|
||||
embeddings = contrib_variables.model_variable(
|
||||
name="weights",
|
||||
shape=[args.vocab_size],
|
||||
dtype=dtypes.float32,
|
||||
initializer=args.initializer,
|
||||
trainable=(trainable and args.trainable),
|
||||
collections=weight_collections)
|
||||
|
||||
return embedding_ops.scattered_embedding_lookup_sparse(
|
||||
embeddings,
|
||||
input_tensor,
|
||||
args.dimension,
|
||||
hash_key=args.hash_key,
|
||||
combiner=args.combiner,
|
||||
name="lookup")
|
||||
|
||||
if args.shared_embedding_name is not None:
|
||||
shared_embedding_collection_name = (
|
||||
"SHARED_EMBEDDING_COLLECTION_" + args.shared_embedding_name.upper())
|
||||
graph = ops.get_default_graph()
|
||||
shared_embedding_collection = (
|
||||
graph.get_collection_ref(shared_embedding_collection_name))
|
||||
shape = [args.vocab_size, args.dimension]
|
||||
if shared_embedding_collection:
|
||||
if len(shared_embedding_collection) > 1:
|
||||
raise ValueError(
|
||||
"Collection %s can only contain one "
|
||||
"(partitioned) variable." % shared_embedding_collection_name)
|
||||
else:
|
||||
embeddings = shared_embedding_collection[0]
|
||||
if embeddings.get_shape() != shape:
|
||||
raise ValueError(
|
||||
"The embedding variable with name {} already "
|
||||
"exists, but its shape does not match required "
|
||||
"embedding shape here. Please make sure to use "
|
||||
"different shared_embedding_name for different "
|
||||
"shared embeddings.".format(args.shared_embedding_name))
|
||||
else:
|
||||
embeddings = contrib_variables.model_variable(
|
||||
name=args.shared_embedding_name,
|
||||
shape=shape,
|
||||
dtype=dtypes.float32,
|
||||
initializer=args.initializer,
|
||||
trainable=(trainable and args.trainable),
|
||||
collections=weight_collections)
|
||||
graph.add_to_collection(shared_embedding_collection_name, embeddings)
|
||||
else:
|
||||
embeddings = contrib_variables.model_variable(
|
||||
name="weights",
|
||||
shape=[args.vocab_size, args.dimension],
|
||||
dtype=dtypes.float32,
|
||||
initializer=args.initializer,
|
||||
trainable=(trainable and args.trainable),
|
||||
collections=weight_collections)
|
||||
|
||||
if _is_variable(embeddings):
|
||||
embeddings = [embeddings]
|
||||
else:
|
||||
embeddings = embeddings._get_variable_list() # pylint: disable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
_maybe_restore_from_checkpoint(column._checkpoint_path(), embeddings)
|
||||
return embedding_ops.safe_embedding_lookup_sparse(
|
||||
embeddings,
|
||||
input_tensor,
|
||||
sparse_weights=weight_tensor,
|
||||
combiner=args.combiner,
|
||||
name=column.name + "weights",
|
||||
max_norm=args.max_norm)
|
||||
|
||||
|
||||
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
|
||||
if checkpoint_path is not None:
|
||||
path, tensor_name = checkpoint_path
|
||||
weights_to_restore = variable
|
||||
if len(variable) == 1:
|
||||
weights_to_restore = variable[0]
|
||||
checkpoint_utils.init_from_checkpoint(path,
|
||||
{tensor_name: weights_to_restore})
|
||||
|
||||
|
||||
def one_hot_column(sparse_id_column):
|
||||
"""Creates an `_OneHotColumn` for a one-hot or multi-hot repr in a DNN.
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
||||
from tensorflow.contrib.framework.python.framework import experimental
|
||||
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
|
||||
from tensorflow.contrib.layers.python.layers import embedding_ops
|
||||
@ -34,118 +33,12 @@ from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def _is_variable(v):
|
||||
"""Returns true if `v` is a variable."""
|
||||
return isinstance(v, (variables.Variable,
|
||||
resource_variable_ops.ResourceVariable))
|
||||
|
||||
|
||||
def _embeddings_from_arguments(column,
|
||||
args,
|
||||
weight_collections,
|
||||
trainable,
|
||||
output_rank=2):
|
||||
"""Returns embeddings for a column based on the computed arguments.
|
||||
|
||||
Args:
|
||||
column: the column name.
|
||||
args: the _DeepEmbeddingLookupArguments for this column.
|
||||
weight_collections: collections to store weights in.
|
||||
trainable: whether these embeddings should be trainable.
|
||||
output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
|
||||
be combined to produce the desired rank.
|
||||
|
||||
Returns:
|
||||
the embeddings.
|
||||
|
||||
Raises:
|
||||
ValueError: if not possible to create.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
|
||||
weight_tensor = None
|
||||
if args.weight_tensor is not None:
|
||||
weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# This option is only enabled for scattered_embedding_column.
|
||||
if args.hash_key:
|
||||
embeddings = contrib_variables.model_variable(
|
||||
name='weights',
|
||||
shape=[args.vocab_size],
|
||||
dtype=dtypes.float32,
|
||||
initializer=args.initializer,
|
||||
trainable=(trainable and args.trainable),
|
||||
collections=weight_collections)
|
||||
|
||||
return embedding_ops.scattered_embedding_lookup_sparse(
|
||||
embeddings, input_tensor, args.dimension,
|
||||
hash_key=args.hash_key,
|
||||
combiner=args.combiner, name='lookup')
|
||||
|
||||
if args.shared_embedding_name is not None:
|
||||
shared_embedding_collection_name = (
|
||||
'SHARED_EMBEDDING_COLLECTION_' + args.shared_embedding_name.upper())
|
||||
graph = ops.get_default_graph()
|
||||
shared_embedding_collection = (
|
||||
graph.get_collection_ref(shared_embedding_collection_name))
|
||||
shape = [args.vocab_size, args.dimension]
|
||||
if shared_embedding_collection:
|
||||
if len(shared_embedding_collection) > 1:
|
||||
raise ValueError('Collection %s can only contain one '
|
||||
'(partitioned) variable.'
|
||||
% shared_embedding_collection_name)
|
||||
else:
|
||||
embeddings = shared_embedding_collection[0]
|
||||
if embeddings.get_shape() != shape:
|
||||
raise ValueError('The embedding variable with name {} already '
|
||||
'exists, but its shape does not match required '
|
||||
'embedding shape here. Please make sure to use '
|
||||
'different shared_embedding_name for different '
|
||||
'shared embeddings.'.format(
|
||||
args.shared_embedding_name))
|
||||
else:
|
||||
embeddings = contrib_variables.model_variable(
|
||||
name=args.shared_embedding_name,
|
||||
shape=shape,
|
||||
dtype=dtypes.float32,
|
||||
initializer=args.initializer,
|
||||
trainable=(trainable and args.trainable),
|
||||
collections=weight_collections)
|
||||
graph.add_to_collection(shared_embedding_collection_name, embeddings)
|
||||
else:
|
||||
embeddings = contrib_variables.model_variable(
|
||||
name='weights',
|
||||
shape=[args.vocab_size, args.dimension],
|
||||
dtype=dtypes.float32,
|
||||
initializer=args.initializer,
|
||||
trainable=(trainable and args.trainable),
|
||||
collections=weight_collections)
|
||||
|
||||
if _is_variable(embeddings):
|
||||
embeddings = [embeddings]
|
||||
else:
|
||||
embeddings = embeddings._get_variable_list() # pylint: disable=protected-access
|
||||
# pylint: disable=protected-access
|
||||
_maybe_restore_from_checkpoint(
|
||||
column._checkpoint_path(), embeddings)
|
||||
return embedding_ops.safe_embedding_lookup_sparse(
|
||||
embeddings,
|
||||
input_tensor,
|
||||
sparse_weights=weight_tensor,
|
||||
combiner=args.combiner,
|
||||
name=column.name + 'weights',
|
||||
max_norm=args.max_norm)
|
||||
|
||||
|
||||
def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
|
||||
"""Reshape the input tensor by the following rule.
|
||||
|
||||
@ -232,7 +125,8 @@ def _input_from_feature_columns(columns_to_tensors,
|
||||
# pylint: disable=protected-access
|
||||
arguments = column._deep_embedding_lookup_arguments(
|
||||
transformed_tensor)
|
||||
output_tensors.append(_embeddings_from_arguments(
|
||||
output_tensors.append(
|
||||
fc._embeddings_from_arguments( # pylint: disable=protected-access
|
||||
column,
|
||||
arguments,
|
||||
weight_collections,
|
||||
@ -393,7 +287,7 @@ def _create_embedding_lookup(column,
|
||||
initializer=embedding_lookup_arguments.initializer,
|
||||
trainable=trainable,
|
||||
collections=weight_collections)
|
||||
if _is_variable(variable):
|
||||
if fc._is_variable(variable): # pylint: disable=protected-access
|
||||
variable = [variable]
|
||||
else:
|
||||
variable = variable._get_variable_list() # pylint: disable=protected-access
|
||||
@ -406,16 +300,6 @@ def _create_embedding_lookup(column,
|
||||
return variable, predictions
|
||||
|
||||
|
||||
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
|
||||
if checkpoint_path is not None:
|
||||
path, tensor_name = checkpoint_path
|
||||
weights_to_restore = variable
|
||||
if len(variable) == 1:
|
||||
weights_to_restore = variable[0]
|
||||
checkpoint_utils.init_from_checkpoint(path,
|
||||
{tensor_name: weights_to_restore})
|
||||
|
||||
|
||||
def _create_joint_embedding_lookup(columns_to_tensors,
|
||||
embedding_lookup_arguments,
|
||||
num_outputs,
|
||||
@ -451,7 +335,7 @@ def _create_joint_embedding_lookup(columns_to_tensors,
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
trainable=trainable,
|
||||
collections=weight_collections)
|
||||
if _is_variable(variable):
|
||||
if fc._is_variable(variable): # pylint: disable=protected-access
|
||||
variable = [variable]
|
||||
else:
|
||||
variable = variable._get_variable_list() # pylint: disable=protected-access
|
||||
@ -634,7 +518,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
|
||||
predictions, shape=(-1, num_outputs)))
|
||||
column_to_variable[column] = variable
|
||||
_log_variable(variable)
|
||||
_maybe_restore_from_checkpoint(column._checkpoint_path(), variable)
|
||||
fc._maybe_restore_from_checkpoint(column._checkpoint_path(), variable) # pylint: disable=protected-access
|
||||
# pylint: enable=protected-access
|
||||
predictions_no_bias = math_ops.add_n(output_tensors)
|
||||
bias = contrib_variables.model_variable(
|
||||
@ -827,10 +711,10 @@ def parse_feature_columns_from_sequence_examples(
|
||||
def _log_variable(variable):
|
||||
if isinstance(variable, list):
|
||||
for var in variable:
|
||||
if _is_variable(variable):
|
||||
if fc._is_variable(variable): # pylint: disable=protected-access
|
||||
logging.info('Created variable %s, with device=%s', var.name,
|
||||
var.device)
|
||||
elif _is_variable(variable):
|
||||
elif fc._is_variable(variable): # pylint: disable=protected-access
|
||||
logging.info('Created variable %s, with device=%s', variable.name,
|
||||
variable.device)
|
||||
|
||||
|
@ -597,12 +597,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
"income":
|
||||
constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]),
|
||||
}
|
||||
output = feature_column_ops.input_from_feature_columns(features, [
|
||||
one_hot_column, embedding_column, real_valued_column])
|
||||
columns = [one_hot_column, embedding_column, real_valued_column]
|
||||
output = feature_column_ops.input_from_feature_columns(features, columns)
|
||||
output_core = fc_core.make_input_layer(features, columns)
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
|
||||
|
||||
def testRealValuedColumn(self):
|
||||
real_valued = feature_column.real_valued_column("price")
|
||||
@ -712,11 +715,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
one_hot_column = feature_column.one_hot_column(weighted_ids_column)
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[one_hot_column])
|
||||
output_core = fc_core.make_input_layer(features, [one_hot_column])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
|
||||
output.eval())
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval(), output_core.eval())
|
||||
|
||||
def testOneHotColumnFromSparseColumnWithKeysSucceedsForDNN(self):
|
||||
ids_column = feature_column.sparse_column_with_keys(
|
||||
@ -729,12 +735,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
features = {"ids": ids_tensor}
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[one_hot_sparse])
|
||||
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
|
||||
output.eval())
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval(), output_core.eval())
|
||||
|
||||
def testOneHotColumnFromMultivalentSparseColumnWithKeysSucceedsForDNN(self):
|
||||
ids_column = feature_column.sparse_column_with_keys(
|
||||
@ -747,12 +756,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
features = {"ids": ids_tensor}
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[one_hot_sparse])
|
||||
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
||||
output.eval())
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval(), output_core.eval())
|
||||
|
||||
def testOneHotColumnFromSparseColumnWithIntegerizedFeaturePassesForDNN(self):
|
||||
ids_column = feature_column.sparse_column_with_integerized_feature(
|
||||
@ -767,10 +779,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
}
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[one_hot_sparse])
|
||||
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
||||
output.eval())
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval(), output_core.eval())
|
||||
|
||||
def testOneHotColumnFromSparseColumnWithHashBucketSucceedsForDNN(self):
|
||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("feat", 10)
|
||||
@ -782,10 +797,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
one_hot_sparse = feature_column.one_hot_column(hashed_sparse)
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[one_hot_sparse])
|
||||
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual([3, 10], output.eval().shape)
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval(), output_core.eval())
|
||||
|
||||
def testEmbeddingColumnSucceedsForDNN(self):
|
||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||
@ -797,9 +815,12 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[embeded_sparse])
|
||||
output_core = fc_core.make_input_layer(features, [embeded_sparse])
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
self.assertAllEqual(output.eval().shape, [4, 10])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
|
||||
|
||||
def testScatteredEmbeddingColumnSucceedsForDNN(self):
|
||||
wire_tensor = sparse_tensor.SparseTensor(
|
||||
@ -838,12 +859,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
initializer=init_ops.constant_initializer(init_value))
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[embeded_sparse])
|
||||
output_core = fc_core.make_input_layer(features, [embeded_sparse])
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
output_eval = output.eval()
|
||||
self.assertAllEqual(output_eval.shape, [2, 10])
|
||||
self.assertAllClose(output_eval, np.tile(init_value, [2, 10]))
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval(), output_core.eval())
|
||||
|
||||
def testEmbeddingColumnWithMultipleInitializersFails(self):
|
||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||
@ -889,10 +913,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
||||
embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
|
||||
output = feature_column_ops.input_from_feature_columns(features,
|
||||
[embeded_sparse])
|
||||
output_core = fc_core.make_input_layer(features, [embeded_sparse])
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual(output.eval().shape, [2, 10])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
|
||||
|
||||
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
|
||||
"""Same as the previous test, but with integer weights."""
|
||||
@ -1534,9 +1562,12 @@ class WeightedSumTest(test.TestCase):
|
||||
features = {"wire": wire_tensor}
|
||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||
features, [hashed_sparse], num_outputs=5)
|
||||
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||
|
||||
def testSparseIntColumn(self):
|
||||
"""Tests a sparse column with int values."""
|
||||
@ -1549,9 +1580,12 @@ class WeightedSumTest(test.TestCase):
|
||||
features = {"wire": wire_tensor}
|
||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||
features, [hashed_sparse], num_outputs=5)
|
||||
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||
|
||||
def testSparseColumnWithDenseInputTensor(self):
|
||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||
@ -1560,9 +1594,12 @@ class WeightedSumTest(test.TestCase):
|
||||
features = {"wire": wire_tensor}
|
||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||
features, [hashed_sparse], num_outputs=5)
|
||||
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||
|
||||
def testWeightedSparseColumn(self):
|
||||
ids = feature_column.sparse_column_with_keys("ids",
|
||||
@ -1579,10 +1616,13 @@ class WeightedSumTest(test.TestCase):
|
||||
features = {"ids": ids_tensor, "weights": weights_tensor}
|
||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||
features, [weighted_ids], num_outputs=5)
|
||||
logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||
|
||||
def testWeightedSparseColumnWithDenseInputTensor(self):
|
||||
ids = feature_column.sparse_column_with_keys(
|
||||
@ -1594,11 +1634,14 @@ class WeightedSumTest(test.TestCase):
|
||||
features = {"ids": ids_tensor, "weights": weights_tensor}
|
||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||
features, [weighted_ids], num_outputs=5)
|
||||
logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
|
||||
|
||||
with self.test_session():
|
||||
variables_lib.global_variables_initializer().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||
|
||||
def testCrossedColumn(self):
|
||||
a = feature_column.sparse_column_with_hash_bucket(
|
||||
@ -1649,6 +1692,8 @@ class WeightedSumTest(test.TestCase):
|
||||
output, column_to_variable, _ = (
|
||||
feature_column_ops.weighted_sum_from_feature_columns(
|
||||
features, [movies], num_outputs=1))
|
||||
logits_core = fc_core.make_linear_model(features, [movies])
|
||||
|
||||
with self.test_session() as sess:
|
||||
variables_lib.initialize_all_variables().run()
|
||||
lookup_ops.tables_initializer().run()
|
||||
@ -1659,6 +1704,8 @@ class WeightedSumTest(test.TestCase):
|
||||
# score for first example = 0.3 (matrix) + 0.1 (head-on) = 0.4
|
||||
# score for second example = 0.5 (winter sleep)
|
||||
self.assertAllClose(output.eval(), [[0.4], [0.5]])
|
||||
# Cross compatibility: Core builder output should equal to contrib.
|
||||
self.assertAllEqual(output.eval().shape, logits_core.eval().shape)
|
||||
|
||||
def testRealValuedColumnWithMultiDimensions(self):
|
||||
real_valued = feature_column.real_valued_column("price", 2)
|
||||
|
@ -36,7 +36,8 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
|
||||
Xavier Glorot and Yoshua Bengio (2010):
|
||||
[Understanding the difficulty of training deep feedforward neural
|
||||
networks. International conference on artificial intelligence and
|
||||
statistics.](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.207.2059&rep=rep1&type=pdf)
|
||||
statistics.](
|
||||
http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
|
||||
|
||||
This initializer is designed to keep the scale of the gradients roughly the
|
||||
same in all layers. In uniform distribution this ends up being the range:
|
||||
|
@ -102,9 +102,10 @@ def _linear_learning_rate(num_linear_feature_columns):
|
||||
def _add_hidden_layer_summary(value, tag):
|
||||
summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
|
||||
summary.histogram("%s/activation" % tag, value)
|
||||
|
||||
|
||||
def _add_layer_summary(value, tag):
|
||||
summary.scalar("%s/fraction_of_zero_values" % tag,
|
||||
nn.zero_fraction(value))
|
||||
summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
|
||||
summary.histogram("%s/activation" % tag, value)
|
||||
|
||||
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib.framework.python.framework import deprecated
|
||||
from tensorflow.contrib.layers.python.layers import optimizers
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training.session_run_hook import SessionRunArgs
|
||||
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import rnn as rnn_cell
|
||||
from tensorflow.contrib.framework.python.framework import deprecated
|
||||
from tensorflow.contrib.layers.python.layers import feature_column_ops
|
||||
from tensorflow.contrib.layers.python.layers import optimizers
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
|
@ -455,6 +455,7 @@ class LegacyConstructorTest(test.TestCase):
|
||||
return {'inputs': inputs}, labels
|
||||
return input_fn
|
||||
|
||||
|
||||
# TODO(jtbates): move all tests below to a benchmark test.
|
||||
class StateSavingRNNEstimatorLearningTest(test.TestCase):
|
||||
"""Learning tests for state saving RNN Estimators."""
|
||||
|
@ -22,6 +22,7 @@ import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from tensorflow.contrib.learn.python.learn import estimator as estimator_lib
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import experiment
|
||||
from tensorflow.contrib.learn.python.learn import run_config
|
||||
@ -38,6 +39,7 @@ from tensorflow.python.training import saver
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
class SheepCounter(object):
|
||||
@ -119,6 +121,15 @@ class TestBaseEstimator(object):
|
||||
compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))
|
||||
|
||||
|
||||
def _check_method_supports_args(method, kwargs):
|
||||
"""Checks that the given method supports the given args."""
|
||||
supported_args = tuple(tf_inspect.getargspec(method).args)
|
||||
for kwarg in kwargs:
|
||||
if kwarg not in supported_args:
|
||||
raise ValueError(
|
||||
'Argument `{}` is not supported in method {}.'.format(kwarg, method))
|
||||
|
||||
|
||||
class TestEstimator(
|
||||
TestBaseEstimator, evaluable.Evaluable, trainable.Trainable):
|
||||
|
||||
@ -126,9 +137,12 @@ class TestEstimator(
|
||||
super(TestEstimator, self).__init__(config, max_evals, eval_dict)
|
||||
tf_logging.info('Create Estimator')
|
||||
|
||||
def evaluate(self, **kwargs):
|
||||
_check_method_supports_args(evaluable.Evaluable.evaluate, kwargs)
|
||||
return super(TestEstimator, self).evaluate(**kwargs)
|
||||
|
||||
def fit(self, **kwargs):
|
||||
if 'hooks' in kwargs:
|
||||
raise ValueError('`hooks` is defined in core Estimator')
|
||||
_check_method_supports_args(trainable.Trainable.fit, kwargs)
|
||||
if 'monitors' in kwargs:
|
||||
self.monitors = kwargs['monitors']
|
||||
return super(TestEstimator, self).train(**kwargs)
|
||||
@ -136,6 +150,13 @@ class TestEstimator(
|
||||
def train(self, **kwargs):
|
||||
raise ValueError('`train` is not defined in Estimator.')
|
||||
|
||||
def export_savedmodel(
|
||||
self, export_dir_base, serving_input_fn, **kwargs):
|
||||
_check_method_supports_args(
|
||||
estimator_lib.Estimator.export_savedmodel, kwargs)
|
||||
return super(TestEstimator, self).export_savedmodel(
|
||||
export_dir_base, serving_input_fn, **kwargs)
|
||||
|
||||
|
||||
class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
|
||||
|
||||
@ -144,17 +165,22 @@ class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
|
||||
tf_logging.info('Create Core Estimator')
|
||||
|
||||
def evaluate(self, **kwargs):
|
||||
if 'eval_metrics' in kwargs:
|
||||
raise ValueError('`eval_metrics` is not defined in core Estimator')
|
||||
_check_method_supports_args(core_estimator.Estimator.evaluate, kwargs)
|
||||
return super(TestCoreEstimator, self).evaluate(**kwargs)
|
||||
|
||||
def train(self, **kwargs):
|
||||
if 'monitors' in kwargs:
|
||||
raise ValueError('`monitors` is not defined in core Estimator')
|
||||
_check_method_supports_args(core_estimator.Estimator.train, kwargs)
|
||||
if 'hooks' in kwargs:
|
||||
self.monitors = kwargs['hooks']
|
||||
return super(TestCoreEstimator, self).train(**kwargs)
|
||||
|
||||
def export_savedmodel(
|
||||
self, export_dir_base, serving_input_receiver_fn, **kwargs):
|
||||
_check_method_supports_args(
|
||||
core_estimator.Estimator.export_savedmodel, kwargs)
|
||||
return super(TestCoreEstimator, self).export_savedmodel(
|
||||
export_dir_base, serving_input_receiver_fn, **kwargs)
|
||||
|
||||
|
||||
class _NoopHook(session_run_hook.SessionRunHook):
|
||||
pass
|
||||
@ -184,6 +210,23 @@ class ExperimentTest(test.TestCase):
|
||||
eval_input_fn='eval_input',
|
||||
eval_metrics='eval_metrics')
|
||||
|
||||
def test_default_output_alternative_key_core_estimator(self):
|
||||
est = TestCoreEstimator()
|
||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||
est,
|
||||
default_output_alternative_key='export_key',
|
||||
exports_to_keep=None)
|
||||
ex = experiment.Experiment(
|
||||
est,
|
||||
train_input_fn='train_input',
|
||||
eval_input_fn='eval_input',
|
||||
train_steps=100,
|
||||
eval_steps=100,
|
||||
export_strategies=export_strategy)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'default_output_alternative_key is not supported'):
|
||||
ex.train_and_evaluate()
|
||||
|
||||
def test_train(self):
|
||||
for est in self._estimators_for_tests():
|
||||
eval_metrics = 'eval_metrics' if not isinstance(
|
||||
@ -508,7 +551,9 @@ class ExperimentTest(test.TestCase):
|
||||
eval_metrics = 'eval_metrics' if not isinstance(
|
||||
est, core_estimator.Estimator) else None
|
||||
export_strategy_1 = saved_model_export_utils.make_export_strategy(
|
||||
est, 'export_input_1', exports_to_keep=None)
|
||||
est,
|
||||
None if isinstance(est, core_estimator.Estimator) else 'export_1',
|
||||
exports_to_keep=None)
|
||||
|
||||
ex = experiment.Experiment(
|
||||
est,
|
||||
@ -531,9 +576,13 @@ class ExperimentTest(test.TestCase):
|
||||
# After reset with list, the count should increase with the number of
|
||||
# items.
|
||||
export_strategy_2 = saved_model_export_utils.make_export_strategy(
|
||||
est, 'export_input_2', exports_to_keep=None)
|
||||
est,
|
||||
None if isinstance(est, core_estimator.Estimator) else 'export_2',
|
||||
exports_to_keep=None)
|
||||
export_strategy_3 = saved_model_export_utils.make_export_strategy(
|
||||
est, 'export_input_3', exports_to_keep=None)
|
||||
est,
|
||||
None if isinstance(est, core_estimator.Estimator) else 'export_3',
|
||||
exports_to_keep=None)
|
||||
|
||||
old_es = ex.reset_export_strategies(
|
||||
[export_strategy_2, export_strategy_3])
|
||||
@ -547,7 +596,9 @@ class ExperimentTest(test.TestCase):
|
||||
est, core_estimator.Estimator) else None
|
||||
noop_hook = _NoopHook()
|
||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||
est, 'export_input', exports_to_keep=None)
|
||||
est,
|
||||
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||
exports_to_keep=None)
|
||||
ex = experiment.Experiment(
|
||||
est,
|
||||
train_input_fn='train_input',
|
||||
@ -625,7 +676,9 @@ class ExperimentTest(test.TestCase):
|
||||
est, core_estimator.Estimator) else None
|
||||
noop_hook = _NoopHook()
|
||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||
est, 'export_input', exports_to_keep=None)
|
||||
est,
|
||||
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||
exports_to_keep=None)
|
||||
ex = experiment.Experiment(
|
||||
est,
|
||||
train_input_fn='train_input',
|
||||
@ -646,7 +699,9 @@ class ExperimentTest(test.TestCase):
|
||||
eval_metrics = 'eval_metrics' if not isinstance(
|
||||
est, core_estimator.Estimator) else None
|
||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||
est, 'export_input', exports_to_keep=None)
|
||||
est,
|
||||
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||
exports_to_keep=None)
|
||||
ex = experiment.Experiment(
|
||||
est,
|
||||
train_input_fn='train_input',
|
||||
@ -796,7 +851,9 @@ class ExperimentTest(test.TestCase):
|
||||
def test_test(self):
|
||||
for est in self._estimators_for_tests():
|
||||
exp_strategy = saved_model_export_utils.make_export_strategy(
|
||||
est, 'export_input', exports_to_keep=None)
|
||||
est,
|
||||
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||
exports_to_keep=None)
|
||||
ex = experiment.Experiment(
|
||||
est,
|
||||
train_input_fn='train_input',
|
||||
|
@ -42,6 +42,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.learn.python.learn.utils import gc
|
||||
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
|
||||
from tensorflow.python.estimator import estimator as core_estimator
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.platform import gfile
|
||||
@ -352,7 +353,8 @@ def make_export_strategy(serving_input_fn,
|
||||
`InputFnOps`.
|
||||
default_output_alternative_key: the name of the head to serve when an
|
||||
incoming serving request does not explicitly request a specific head.
|
||||
Not needed for single-headed models.
|
||||
Must be `None` if the estimator inherits from ${tf.estimator.Estimator}
|
||||
or for single-headed models.
|
||||
assets_extra: A dict specifying how to populate the assets.extra directory
|
||||
within the exported SavedModel. Each key should give the destination
|
||||
path (including the filename) relative to the assets.extra directory.
|
||||
@ -384,7 +386,23 @@ def make_export_strategy(serving_input_fn,
|
||||
|
||||
Returns:
|
||||
The string path to the exported directory.
|
||||
|
||||
Raises:
|
||||
ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
|
||||
and `default_output_alternative_key` was specified.
|
||||
"""
|
||||
if isinstance(estimator, core_estimator.Estimator):
|
||||
if default_output_alternative_key is not None:
|
||||
raise ValueError(
|
||||
'default_output_alternative_key is not supported in core '
|
||||
'Estimator. Given: {}'.format(default_output_alternative_key))
|
||||
export_result = estimator.export_savedmodel(
|
||||
export_dir_base,
|
||||
serving_input_fn,
|
||||
assets_extra=assets_extra,
|
||||
as_text=as_text,
|
||||
checkpoint_path=checkpoint_path)
|
||||
else:
|
||||
export_result = estimator.export_savedmodel(
|
||||
export_dir_base,
|
||||
serving_input_fn,
|
||||
|
@ -1,9 +1,9 @@
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
|
||||
py_library(
|
||||
|
@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""##Signal ops.
|
||||
|
||||
"""
|
||||
@@frames
|
||||
"""
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Signal ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Signal ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""General shape ops for frames."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -23,8 +24,10 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
def frames(signal, frame_length, frame_step, name=None):
|
||||
"""Frame a signal into overlapping frames.
|
||||
|
||||
May be used in front of spectral functions.
|
||||
|
||||
For example:
|
||||
@ -44,6 +47,9 @@ def frames(signal, frame_length, frame_step, name=None):
|
||||
|
||||
Returns:
|
||||
A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`.
|
||||
|
||||
Raises:
|
||||
ValueError: if signal does not have rank 2.
|
||||
"""
|
||||
with ops.name_scope(name, "frames", [signal, frame_length, frame_step]):
|
||||
signal = ops.convert_to_tensor(signal, name="signal")
|
||||
@ -61,8 +67,8 @@ def frames(signal, frame_length, frame_step, name=None):
|
||||
num_frames = 1 + math_ops.cast(num_frames, dtypes.int32)
|
||||
|
||||
pad_length = (num_frames - 1) * frame_step + frame_length
|
||||
pad_signal = array_ops.pad(
|
||||
signal, [[0, 0], [0, pad_length - signal_length]])
|
||||
pad_signal = array_ops.pad(signal, [[0, 0], [0,
|
||||
pad_length - signal_length]])
|
||||
|
||||
indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0)
|
||||
indices_frames = array_ops.tile(indices_frame, [num_frames, 1])
|
||||
@ -73,9 +79,9 @@ def frames(signal, frame_length, frame_step, name=None):
|
||||
|
||||
indices = indices_frames + indices_steps
|
||||
|
||||
# TODO(Androbin): remove `transpose` when `gather` gets `axis` support
|
||||
# TODO(androbin): remove `transpose` when `gather` gets `axis` support
|
||||
pad_signal = array_ops.transpose(pad_signal)
|
||||
frames = array_ops.gather(pad_signal, indices)
|
||||
frames = array_ops.transpose(frames, perm=[2, 0, 1])
|
||||
signal_frames = array_ops.gather(pad_signal, indices)
|
||||
signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1])
|
||||
|
||||
return frames
|
||||
return signal_frames
|
||||
|
@ -97,6 +97,29 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pprof_profiler",
|
||||
srcs = ["pprof_profiler.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["@pprof_profile_proto//:pprof_proto_py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "pprof_profiler_test",
|
||||
srcs = ["pprof_profiler_test.py"],
|
||||
main = "pprof_profiler_test.py",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"], # TODO(annarev): get it working with pip.
|
||||
deps = [
|
||||
":pprof_profiler",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@pprof_profile_proto//:pprof_proto_py",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets. These must be at the end for syncrepo.
|
||||
|
||||
|
445
tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
Normal file
445
tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
Normal file
@ -0,0 +1,445 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Profiler for TensorFlow models that outputs data in pprof format.
|
||||
|
||||
See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
|
||||
profile format.
|
||||
The following needs to be set for profiler to work:
|
||||
* trace_level needs to be set to FULL_TRACE
|
||||
* run_metadata object should be passed in to session.run call
|
||||
|
||||
Sample usage:
|
||||
options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
|
||||
run_metadata = tf.RunMetadata()
|
||||
|
||||
with tf.Session as sess:
|
||||
...
|
||||
sess.run(computation, run_metadata=run_metadata, options=options)
|
||||
pprof_profiler.profile(sess.graph, run_metadata, output_dir)
|
||||
|
||||
|
||||
The code above would output a pprof profile to separate output_dir/.*.pb.gz
|
||||
file for each device. These files can be passed to pprof for formatting.
|
||||
For e.g.:
|
||||
pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import defaultdict
|
||||
from collections import namedtuple
|
||||
import gzip
|
||||
import os
|
||||
import string
|
||||
import sys
|
||||
import time
|
||||
|
||||
from pprof import profile_pb2
|
||||
|
||||
|
||||
if sys.version_info < (3,):
|
||||
maketrans = string.maketrans
|
||||
else:
|
||||
maketrans = str.maketrans
|
||||
|
||||
|
||||
ProfileDatum = namedtuple('ProfileDatum', [
|
||||
'node_exec_stats', 'op_type', 'traceback'])
|
||||
|
||||
|
||||
class StringTable(object):
|
||||
"""Keeps track of strings to add to string_table in pprof proto."""
|
||||
|
||||
def __init__(self):
|
||||
# Pprof requires first entry in string_table to be ''.
|
||||
self._string_table = ['']
|
||||
self._string_to_index = {'': 0}
|
||||
|
||||
def index_of(self, value_str):
|
||||
"""Get index of value_str in the string table.
|
||||
|
||||
If value_str is not in the string table, we will add it at the end
|
||||
and then return the new index.
|
||||
Args:
|
||||
value_str: (string) Value to lookup/add in/to the string table.
|
||||
|
||||
Returns:
|
||||
Index of value_str in the string table.
|
||||
"""
|
||||
if value_str is None:
|
||||
value_str = ''
|
||||
if value_str in self._string_to_index:
|
||||
return self._string_to_index[value_str]
|
||||
index = len(self._string_table)
|
||||
self._string_table.append(value_str)
|
||||
self._string_to_index[value_str] = index
|
||||
return index
|
||||
|
||||
def next_index(self):
|
||||
"""Gets index that would be assigned to the next added string.
|
||||
|
||||
Returns:
|
||||
Index of the next string if it was added.
|
||||
"""
|
||||
return len(self._string_table)
|
||||
|
||||
def string_table(self):
|
||||
"""Returns a list of strings to store in pprof's string_table."""
|
||||
return self._string_table
|
||||
|
||||
|
||||
class Functions(object):
|
||||
"""Keeps track of `Function` protos for pprof profile."""
|
||||
|
||||
def __init__(self, string_table):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
string_table: A `StringTable` object.
|
||||
"""
|
||||
self._string_table = string_table
|
||||
# Maps tuples in the form (file_path, function_name, start_line_number)
|
||||
# to `Function` protos.
|
||||
self._function_key_to_function = {}
|
||||
|
||||
def index_of(self, file_path, function_name, function_start_line):
|
||||
"""Returns index of the function, adding the function if needed.
|
||||
|
||||
Args:
|
||||
file_path: (string) Path to file where the function is defined.
|
||||
function_name: (string) Function name.
|
||||
function_start_line: (integer) Start line number of function definition.
|
||||
|
||||
Returns:
|
||||
Function index.
|
||||
"""
|
||||
function_key = (file_path, function_name, function_start_line)
|
||||
if function_key in self._function_key_to_function:
|
||||
return self._function_key_to_function[function_key].id
|
||||
else:
|
||||
# Function indexes should start from 1
|
||||
function_index = len(self._function_key_to_function) + 1
|
||||
function = profile_pb2.Function()
|
||||
function.id = function_index
|
||||
function.name = self._string_table.index_of(function_name)
|
||||
function.filename = self._string_table.index_of(file_path)
|
||||
function.start_line = function_start_line
|
||||
self._function_key_to_function[function_key] = function
|
||||
return function_index
|
||||
|
||||
def function_protos(self):
|
||||
"""Returns list of `profile_pb2.Function` protos."""
|
||||
return self._function_key_to_function.values()
|
||||
|
||||
|
||||
class Locations(object):
|
||||
"""Keeps track of `Location` protos for pprof profile.
|
||||
|
||||
`Locations` store information about function call locations.
|
||||
"""
|
||||
|
||||
def __init__(self, functions):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
functions: A `Functions` object.
|
||||
"""
|
||||
self._functions = functions
|
||||
# Maps tuples in the form (file_path, called_function_name, line_number)
|
||||
# to `Location` protos.
|
||||
self._location_key_to_location = {}
|
||||
|
||||
def index_of(
|
||||
self, file_path, line_number, called_function_name, called_file_path,
|
||||
called_function_start_line):
|
||||
"""Returns index of the location, adding the location if needed.
|
||||
|
||||
Args:
|
||||
file_path: (string) Path to file that makes the call.
|
||||
line_number: (integer) Call line number.
|
||||
called_function_name: (string) Function name of the function called at
|
||||
`file_path` and `line_number`.
|
||||
called_file_path: (string) Path to file where the called function is
|
||||
defined.
|
||||
called_function_start_line: (integer) Start line number of called
|
||||
function definition in `called_file_path` file.
|
||||
|
||||
Returns:
|
||||
Index of location.
|
||||
"""
|
||||
location_key = (file_path, called_function_name, line_number)
|
||||
if location_key in self._location_key_to_location:
|
||||
location = self._location_key_to_location[location_key]
|
||||
return location.id
|
||||
else:
|
||||
# Location indexes should start from 1
|
||||
location_index = len(self._location_key_to_location) + 1
|
||||
location = profile_pb2.Location()
|
||||
location.id = location_index
|
||||
self._location_key_to_location[location_key] = location
|
||||
|
||||
line = location.line.add()
|
||||
line.function_id = self._functions.index_of(
|
||||
called_file_path, called_function_name, called_function_start_line)
|
||||
line.line = line_number
|
||||
return location_index
|
||||
|
||||
def location_protos(self):
|
||||
"""Returns list of `profile_pb2.Location` protos."""
|
||||
return self._location_key_to_location.values()
|
||||
|
||||
|
||||
class Samples(object):
|
||||
"""Keeps track of `Sample` protos for pprof profile.
|
||||
|
||||
Samples store the following statistics in order:
|
||||
count, all_time, op_time
|
||||
"""
|
||||
|
||||
def __init__(self, string_table):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
string_table: A `StringTable` object.
|
||||
"""
|
||||
self._string_table = string_table
|
||||
# TODO(annarev): figure out if location is unique for each node name.
|
||||
# If not, also key this dictionary based on location ids.
|
||||
self._node_name_to_sample = {}
|
||||
|
||||
def add(self, datum, location_ids):
|
||||
"""Adds a sample data point.
|
||||
|
||||
Args:
|
||||
datum: `ProfileDatum` to add a sample for.
|
||||
location_ids: List of numberic location ids for this
|
||||
sample.
|
||||
"""
|
||||
node_name = datum.node_exec_stats.node_name
|
||||
if node_name in self._node_name_to_sample:
|
||||
sample = self._node_name_to_sample[node_name]
|
||||
sample.location_id.extend(location_ids)
|
||||
else:
|
||||
sample = profile_pb2.Sample()
|
||||
# Sample stores 3 values: count, all_time, op_time
|
||||
sample.value.extend([0, 0, 0])
|
||||
|
||||
label = sample.label.add()
|
||||
label.key = self._string_table.index_of('node_name')
|
||||
label.str = self._string_table.index_of(node_name)
|
||||
label = sample.label.add()
|
||||
label.key = self._string_table.index_of('op_type')
|
||||
label.str = self._string_table.index_of(datum.op_type)
|
||||
self._node_name_to_sample[node_name] = sample
|
||||
sample.value[0] += 1
|
||||
sample.value[1] += datum.node_exec_stats.all_end_rel_micros
|
||||
sample.value[2] += (
|
||||
datum.node_exec_stats.op_end_rel_micros -
|
||||
datum.node_exec_stats.op_start_rel_micros)
|
||||
|
||||
def get_sample_protos(self):
|
||||
"""Returns list of `Sample` protos for pprof profile."""
|
||||
return self._node_name_to_sample.values()
|
||||
|
||||
|
||||
class PprofProfiler(object):
|
||||
"""Creates profiles in pprof format."""
|
||||
|
||||
def __init__(self, graph, run_metadata):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
graph: A `Graph` instance.
|
||||
run_metadata: A list of `RunMetadata` objects.
|
||||
"""
|
||||
self._graph = graph
|
||||
self._run_metadata = run_metadata
|
||||
self._string_table = StringTable()
|
||||
self._functions = Functions(self._string_table)
|
||||
self._locations = Locations(self._functions)
|
||||
|
||||
def profile(self):
|
||||
"""Generates pprof profiles.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping from device name to proto in `profile_pb2.Profile`
|
||||
format.
|
||||
"""
|
||||
profiles = {}
|
||||
data_generator_func = self._get_profile_data_generator()
|
||||
for device_index, device_stats in enumerate(
|
||||
self._run_metadata.step_stats.dev_stats):
|
||||
# Create profile
|
||||
pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
|
||||
if not pprof_proto.sample:
|
||||
print(
|
||||
'Not enough data to create profile for device %s. Did you pass '
|
||||
'RunMetadata to session.run call?' % device_stats.device)
|
||||
continue
|
||||
# Add device name comment
|
||||
device_count = len(self._run_metadata.step_stats.dev_stats)
|
||||
device_description = (
|
||||
'Device %d of %d: %s' %
|
||||
(device_index + 1, device_count, device_stats.device))
|
||||
device_description_str_index = self._string_table.next_index()
|
||||
pprof_proto.string_table.append(device_description)
|
||||
pprof_proto.comment.append(device_description_str_index)
|
||||
profiles[device_stats.device] = pprof_proto
|
||||
return profiles
|
||||
|
||||
def _get_pprof_proto(self, profile_datum_generator):
|
||||
"""Returns profile data in pprof proto format.
|
||||
|
||||
Args:
|
||||
profile_datum_generator: Generator outputting `ProfileDatum` objects.
|
||||
|
||||
Returns:
|
||||
A proto in pprof format.
|
||||
"""
|
||||
pprof_profile = profile_pb2.Profile()
|
||||
samples = Samples(self._string_table)
|
||||
|
||||
for datum in profile_datum_generator:
|
||||
if not datum.traceback:
|
||||
continue
|
||||
|
||||
stack_frame = datum.traceback[-1]
|
||||
after_apply_op = False
|
||||
location_ids = []
|
||||
|
||||
# We add locations from stack trace in bottom-up order.
|
||||
for stack_frame_index in reversed(range(len(datum.traceback) - 1)):
|
||||
prev_stack_frame = stack_frame
|
||||
stack_frame = datum.traceback[stack_frame_index]
|
||||
|
||||
# Call at current frame calls function at previous frame.
|
||||
prev_file_path = prev_stack_frame[0]
|
||||
prev_function = prev_stack_frame[2]
|
||||
prev_function_start_line = prev_stack_frame[4]
|
||||
curr_file_path = stack_frame[0]
|
||||
curr_line_number = stack_frame[1]
|
||||
|
||||
# Skip all calls up to apply_op since they are the same for all ops.
|
||||
if not after_apply_op:
|
||||
if prev_function == 'apply_op':
|
||||
after_apply_op = True
|
||||
continue
|
||||
location_index = self._locations.index_of(
|
||||
curr_file_path, curr_line_number,
|
||||
prev_function, prev_file_path, prev_function_start_line)
|
||||
location_ids.append(location_index)
|
||||
samples.add(datum, location_ids)
|
||||
|
||||
sample_type_description = 'count'
|
||||
sample_type = pprof_profile.sample_type.add()
|
||||
sample_type.type = self._string_table.index_of(sample_type_description)
|
||||
sample_type.unit = self._string_table.index_of('count')
|
||||
sample_type_description = 'all_time'
|
||||
sample_type = pprof_profile.sample_type.add()
|
||||
sample_type.type = self._string_table.index_of(sample_type_description)
|
||||
sample_type.unit = self._string_table.index_of('nanoseconds')
|
||||
sample_type_description = 'op_time'
|
||||
sample_type = pprof_profile.sample_type.add()
|
||||
sample_type.type = self._string_table.index_of(sample_type_description)
|
||||
sample_type.unit = self._string_table.index_of('nanoseconds')
|
||||
|
||||
pprof_profile.string_table.extend(self._string_table.string_table())
|
||||
pprof_profile.sample.extend(samples.get_sample_protos())
|
||||
pprof_profile.function.extend(self._functions.function_protos())
|
||||
pprof_profile.location.extend(self._locations.location_protos())
|
||||
return pprof_profile
|
||||
|
||||
def _get_profile_data_generator(self):
|
||||
"""Get function that generates `ProfileDatum` objects.
|
||||
|
||||
Returns:
|
||||
A function that generates `ProfileDatum` objects.
|
||||
"""
|
||||
node_to_traceback = defaultdict(list)
|
||||
node_to_op_type = defaultdict(str)
|
||||
for op in self._graph.get_operations():
|
||||
node_to_traceback[op.name] = op.traceback_with_start_lines
|
||||
node_to_op_type[op.name] = op.type
|
||||
|
||||
def profile_data_generator(device_step_stats):
|
||||
for node_stats in device_step_stats.node_stats:
|
||||
if node_stats.node_name == '_SOURCE' or node_stats.node_name == '_SINK':
|
||||
continue
|
||||
yield ProfileDatum(
|
||||
node_stats,
|
||||
node_to_op_type[node_stats.node_name],
|
||||
node_to_traceback[node_stats.node_name])
|
||||
|
||||
return profile_data_generator
|
||||
|
||||
|
||||
def get_profiles(graph, run_metadata):
|
||||
"""Generate profiles in pprof format.
|
||||
|
||||
See https://github.com/google/pprof/blob/master/proto/profile.proto
|
||||
for pprof proto format.
|
||||
|
||||
Args:
|
||||
graph: A `Graph` object.
|
||||
run_metadata: A `RunMetadata` proto.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping from device name to pprof proto for that device.
|
||||
"""
|
||||
return PprofProfiler(graph, run_metadata).profile()
|
||||
|
||||
|
||||
def profile(graph, run_metadata, output_dir=None):
|
||||
"""Generate profiles in pprof format.
|
||||
|
||||
See https://github.com/google/pprof/blob/master/proto/profile.proto
|
||||
for pprof proto format.
|
||||
|
||||
Args:
|
||||
graph: A `Graph` object.
|
||||
run_metadata: A `RunMetadata` proto.
|
||||
output_dir: (string) Directory to output pprof profile to.
|
||||
Profile files for each device will be stored in compressed
|
||||
serialized proto format. If output_dir is None, profile protos
|
||||
will be printed to stdout instead.
|
||||
|
||||
Returns:
|
||||
List of output files created by this profile call.
|
||||
(Note: this list will be empty if output_dir is None)
|
||||
"""
|
||||
profiles = get_profiles(graph, run_metadata)
|
||||
output_file_template = None
|
||||
if output_dir:
|
||||
if not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
time_suffix = time.strftime('%Y%m%d%H%M%S')
|
||||
output_file_template = os.path.join(
|
||||
output_dir, '%s_' + time_suffix + '.pb.gz')
|
||||
|
||||
profile_files = []
|
||||
for device, pprof_proto in profiles.items():
|
||||
if output_file_template is None:
|
||||
print('No output directory specified, printing to stdout instead.')
|
||||
print(pprof_proto)
|
||||
else:
|
||||
device_name = str(device).strip('/').translate(
|
||||
maketrans('/:', '__'))
|
||||
profile_file = output_file_template % device_name
|
||||
profile_files.append(profile_file)
|
||||
with gzip.open(profile_file, 'w') as output_file:
|
||||
print('Writing profile to %s...' % profile_file)
|
||||
output_file.write(pprof_proto.SerializeToString())
|
||||
return profile_files
|
@ -0,0 +1,164 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for pprof_profiler."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gzip
|
||||
|
||||
from pprof import profile_pb2
|
||||
from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler
|
||||
from tensorflow.core.framework import step_stats_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PprofProfilerTest(test.TestCase):
|
||||
|
||||
def testDataEmpty(self):
|
||||
output_dir = test.get_temp_dir()
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
graph = test.mock.MagicMock()
|
||||
graph.get_operations.return_value = []
|
||||
|
||||
profiles = pprof_profiler.get_profiles(graph, run_metadata)
|
||||
self.assertEquals(0, len(profiles))
|
||||
profile_files = pprof_profiler.profile(
|
||||
graph, run_metadata, output_dir)
|
||||
self.assertEquals(0, len(profile_files))
|
||||
|
||||
def testRunMetadataEmpty(self):
|
||||
output_dir = test.get_temp_dir()
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
graph = test.mock.MagicMock()
|
||||
op1 = test.mock.MagicMock()
|
||||
op1.name = 'Add/123'
|
||||
op1.traceback = [('a/b/file1', 10, 'some_var')]
|
||||
op1.type = 'add'
|
||||
graph.get_operations.return_value = [op1]
|
||||
|
||||
profiles = pprof_profiler.get_profiles(graph, run_metadata)
|
||||
self.assertEquals(0, len(profiles))
|
||||
profile_files = pprof_profiler.profile(
|
||||
graph, run_metadata, output_dir)
|
||||
self.assertEquals(0, len(profile_files))
|
||||
|
||||
def testValidProfile(self):
|
||||
output_dir = test.get_temp_dir()
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
|
||||
node1 = step_stats_pb2.NodeExecStats(
|
||||
node_name='Add/123',
|
||||
op_start_rel_micros=3,
|
||||
op_end_rel_micros=5,
|
||||
all_end_rel_micros=4)
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
device1 = run_metadata.step_stats.dev_stats.add()
|
||||
device1.device = 'deviceA'
|
||||
device1.node_stats.extend([node1])
|
||||
|
||||
graph = test.mock.MagicMock()
|
||||
op1 = test.mock.MagicMock()
|
||||
op1.name = 'Add/123'
|
||||
op1.traceback = [
|
||||
('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
|
||||
op1.type = 'add'
|
||||
graph.get_operations.return_value = [op1]
|
||||
|
||||
expected_proto = """sample_type {
|
||||
type: 5
|
||||
unit: 5
|
||||
}
|
||||
sample_type {
|
||||
type: 6
|
||||
unit: 7
|
||||
}
|
||||
sample_type {
|
||||
type: 8
|
||||
unit: 7
|
||||
}
|
||||
sample {
|
||||
value: 1
|
||||
value: 4
|
||||
value: 2
|
||||
label {
|
||||
key: 1
|
||||
str: 2
|
||||
}
|
||||
label {
|
||||
key: 3
|
||||
str: 4
|
||||
}
|
||||
}
|
||||
string_table: ""
|
||||
string_table: "node_name"
|
||||
string_table: "Add/123"
|
||||
string_table: "op_type"
|
||||
string_table: "add"
|
||||
string_table: "count"
|
||||
string_table: "all_time"
|
||||
string_table: "nanoseconds"
|
||||
string_table: "op_time"
|
||||
string_table: "Device 1 of 1: deviceA"
|
||||
comment: 9
|
||||
"""
|
||||
# Test with protos
|
||||
profiles = pprof_profiler.get_profiles(graph, run_metadata)
|
||||
self.assertEquals(1, len(profiles))
|
||||
self.assertTrue('deviceA' in profiles)
|
||||
self.assertEquals(expected_proto, str(profiles['deviceA']))
|
||||
# Test with files
|
||||
profile_files = pprof_profiler.profile(
|
||||
graph, run_metadata, output_dir)
|
||||
self.assertEquals(1, len(profile_files))
|
||||
with gzip.open(profile_files[0]) as profile_file:
|
||||
profile_contents = profile_file.read()
|
||||
profile = profile_pb2.Profile()
|
||||
profile.ParseFromString(profile_contents)
|
||||
self.assertEquals(expected_proto, str(profile))
|
||||
|
||||
def testProfileWithWhileLoop(self):
|
||||
options = config_pb2.RunOptions()
|
||||
options.trace_level = config_pb2.RunOptions.FULL_TRACE
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
|
||||
num_iters = 5
|
||||
with self.test_session() as sess:
|
||||
i = constant_op.constant(0)
|
||||
c = lambda i: math_ops.less(i, num_iters)
|
||||
b = lambda i: math_ops.add(i, 1)
|
||||
r = control_flow_ops.while_loop(c, b, [i])
|
||||
sess.run(r, options=options, run_metadata=run_metadata)
|
||||
profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
|
||||
self.assertEquals(1, len(profiles))
|
||||
profile = next(iter(profiles.values()))
|
||||
add_samples = [] # Samples for the while/Add node
|
||||
for sample in profile.sample:
|
||||
if profile.string_table[sample.label[0].str] == 'while/Add':
|
||||
add_samples.append(sample)
|
||||
# Values for same nodes are aggregated.
|
||||
self.assertEquals(1, len(add_samples))
|
||||
# Value of "count" should be equal to number of iterations.
|
||||
self.assertEquals(num_iters, add_samples[0].value[0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -272,7 +272,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
|
||||
self_.qpn = qp_->qp_num;
|
||||
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
|
||||
union ibv_gid gid;
|
||||
CHECK(!ibv_query_gid(adapter_->context_, (uint8_t) 1, 0, &gid)) << "Query gid";
|
||||
CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
|
||||
<< "Query gid";
|
||||
self_.snp = gid.global.subnet_prefix;
|
||||
self_.iid = gid.global.interface_id;
|
||||
}
|
||||
|
@ -248,8 +248,8 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
|
||||
tdata.size(), do_nothing);
|
||||
slices[1] = ::grpc::Slice(s1, ::grpc::Slice::STEAL_REF);
|
||||
|
||||
gpr_slice s2 = gpr_slice_new(const_cast<TensorBuffer*>(buf),
|
||||
0, unref_tensorbuffer);
|
||||
gpr_slice s2 =
|
||||
gpr_slice_new(const_cast<TensorBuffer*>(buf), 0, unref_tensorbuffer);
|
||||
slices[2] = ::grpc::Slice(s2, ::grpc::Slice::STEAL_REF);
|
||||
num_slices += 2;
|
||||
}
|
||||
|
@ -135,6 +135,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "virtual_placer",
|
||||
srcs = ["virtual_placer.cc"],
|
||||
hdrs = ["virtual_placer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":op_performance_data_cc",
|
||||
":utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:devices",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "virtual_scheduler",
|
||||
srcs = ["virtual_scheduler.cc"],
|
||||
@ -194,3 +210,24 @@ cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "analytical_cost_estimator",
|
||||
srcs = ["analytical_cost_estimator.cc"],
|
||||
hdrs = ["analytical_cost_estimator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cost_estimator",
|
||||
":graph_properties",
|
||||
":op_level_cost_estimator",
|
||||
":op_performance_data_cc",
|
||||
":utils",
|
||||
":virtual_placer",
|
||||
":virtual_scheduler",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
],
|
||||
)
|
||||
|
128
tensorflow/core/grappler/costs/analytical_cost_estimator.cc
Normal file
128
tensorflow/core/grappler/costs/analytical_cost_estimator.cc
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
|
||||
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/graph/types.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/costs/virtual_placer.h"
|
||||
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
|
||||
bool use_static_shapes)
|
||||
: cluster_(cluster), use_static_shapes_(use_static_shapes) {}
|
||||
|
||||
Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
|
||||
item_ = item;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
|
||||
CostGraphDef* cost_graph,
|
||||
Costs* costs) const {
|
||||
GrapplerItem item = item_;
|
||||
item.graph = optimized_graph;
|
||||
GraphProperties properties(item);
|
||||
Status status;
|
||||
if (use_static_shapes_) {
|
||||
status = properties.InferStatically();
|
||||
} else {
|
||||
status = properties.InferDynamically(cluster_);
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
costs->execution_time = Costs::Duration::max();
|
||||
return status;
|
||||
}
|
||||
|
||||
std::unordered_map<string, CostGraphDef::Node*> name_to_cost;
|
||||
if (cost_graph) {
|
||||
for (auto& node : *cost_graph->mutable_node()) {
|
||||
name_to_cost[node.name()] = &node;
|
||||
}
|
||||
}
|
||||
std::vector<string> inaccurate_nodes;
|
||||
VirtualScheduler scheduler(optimized_graph, item_.fetch);
|
||||
VirtualPlacer placer(cluster_);
|
||||
Costs node_costs;
|
||||
do {
|
||||
const NodeDef* node = scheduler.GetCurrNode();
|
||||
std::vector<OpInfo::TensorProperties> inputs =
|
||||
properties.GetInputProperties(node->name());
|
||||
|
||||
OpInfo::DeviceProperties device = placer.get_device(*node);
|
||||
OpInfo op_info;
|
||||
op_info.set_op(node->op());
|
||||
*op_info.mutable_attr() = node->attr();
|
||||
for (auto& input : inputs) {
|
||||
op_info.add_inputs()->Swap(&input);
|
||||
}
|
||||
op_info.mutable_device()->Swap(&device);
|
||||
|
||||
node_costs = node_estimator_.PredictCosts(op_info);
|
||||
if (node_costs.inaccurate) {
|
||||
inaccurate_nodes.push_back(node->name());
|
||||
}
|
||||
if (cost_graph) {
|
||||
auto it = name_to_cost.find(node->name());
|
||||
CostGraphDef::Node* cost_node;
|
||||
if (it != name_to_cost.end()) {
|
||||
cost_node = it->second;
|
||||
} else {
|
||||
cost_node = cost_graph->add_node();
|
||||
cost_node->set_name(node->name());
|
||||
}
|
||||
string device_name = properties.GetDeviceName(node->name());
|
||||
cost_node->set_device(device_name);
|
||||
cost_node->set_compute_cost(
|
||||
node_costs.execution_time.asMicroSeconds().count());
|
||||
cost_node->set_compute_time(
|
||||
node_costs.compute_time.asMicroSeconds().count());
|
||||
cost_node->set_memory_time(
|
||||
node_costs.memory_time.asMicroSeconds().count());
|
||||
std::vector<OpInfo::TensorProperties> outputs =
|
||||
properties.GetOutputProperties(node->name());
|
||||
for (const auto& output : outputs) {
|
||||
auto output_info = cost_node->add_output_info();
|
||||
output_info->set_dtype(output.dtype());
|
||||
auto shape = output_info->mutable_shape();
|
||||
*shape = output.shape();
|
||||
}
|
||||
}
|
||||
} while (scheduler.MarkCurrNodeExecuted(node_costs));
|
||||
|
||||
*costs = scheduler.Summary();
|
||||
VLOG(1) << inaccurate_nodes.size() << " out of "
|
||||
<< optimized_graph.node_size()
|
||||
<< " nodes have inaccurate time estimation";
|
||||
for (const auto& node : inaccurate_nodes) {
|
||||
VLOG(2) << "Node with inaccurate time estimation: " << node;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
63
tensorflow/core/grappler/costs/analytical_cost_estimator.h
Normal file
63
tensorflow/core/grappler/costs/analytical_cost_estimator.h
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
|
||||
|
||||
#include "tensorflow/core/grappler/costs/cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class CostGraphDef;
|
||||
class GraphDef;
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class Cluster;
|
||||
struct GrapplerItem;
|
||||
|
||||
// Estimate the cost of running a Grappler item based on the theoretical
|
||||
// performance of the hardware that will run the model.
|
||||
class AnalyticalCostEstimator : public CostEstimator {
|
||||
public:
|
||||
// Does not take ownership of cluster.
|
||||
explicit AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
|
||||
~AnalyticalCostEstimator() override {}
|
||||
|
||||
// Initalizes the estimator for the specified grappler item.
|
||||
// This implementation always returns OK.
|
||||
Status Initialize(const GrapplerItem& item) override;
|
||||
|
||||
// Predict the performance of each node of the optimized graph and annotate
|
||||
// the CostGraphDef with the corresponding estimates. Also returns the
|
||||
// expected latency for the whole graph.
|
||||
Status PredictCosts(const GraphDef& optimized_graph, CostGraphDef* cost_graph,
|
||||
Costs* overall_latency) const override;
|
||||
|
||||
private:
|
||||
Cluster* cluster_; // Not owned.
|
||||
GrapplerItem item_;
|
||||
OpLevelCostEstimator node_estimator_;
|
||||
bool use_static_shapes_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
|
@ -80,7 +80,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
|
||||
const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo();
|
||||
// Check if vector instructions are available, and refine performance
|
||||
// prediction based on this.
|
||||
gflops = local_cpu.num_cores() * local_cpu.frequency();
|
||||
// Frequencies are stored in MHz in the DeviceProperties.
|
||||
gflops = local_cpu.num_cores() * local_cpu.frequency() * 1e-3;
|
||||
if (bandwidth < 0) {
|
||||
if (local_cpu.bandwidth() > 0) {
|
||||
bandwidth = local_cpu.bandwidth() / 1e6;
|
||||
@ -105,7 +106,7 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
|
||||
// Pascal.
|
||||
cores_per_multiprocessor = 64;
|
||||
}
|
||||
gflops = local_gpu.num_cores() * local_gpu.frequency() *
|
||||
gflops = local_gpu.num_cores() * local_gpu.frequency() * 1e-3 *
|
||||
cores_per_multiprocessor * kOpsPerMac;
|
||||
if (bandwidth < 0) {
|
||||
CHECK(local_gpu.bandwidth() > 0);
|
||||
|
@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() {
|
||||
// Combine cpu family and model into the model string.
|
||||
device.set_model(
|
||||
strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
|
||||
device.set_frequency(port::NominalCPUFrequency() * 1e-9);
|
||||
device.set_frequency(port::NominalCPUFrequency() * 1e-6);
|
||||
device.set_num_cores(port::NumSchedulableCPUs());
|
||||
device.set_l1_cache_size(Eigen::l1CacheSize());
|
||||
device.set_l2_cache_size(Eigen::l2CacheSize());
|
||||
@ -175,7 +175,7 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) {
|
||||
if (error == cudaSuccess) {
|
||||
device.set_vendor("NVidia");
|
||||
device.set_model(properties.name);
|
||||
device.set_frequency(properties.clockRate / 1000);
|
||||
device.set_frequency(properties.clockRate * 1e-3);
|
||||
device.set_num_cores(properties.multiProcessorCount);
|
||||
device.set_num_registers(properties.regsPerMultiprocessor);
|
||||
// For compute capability less than 5, l1 cache size is configurable to
|
||||
|
57
tensorflow/core/grappler/costs/virtual_placer.cc
Normal file
57
tensorflow/core/grappler/costs/virtual_placer.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/costs/virtual_placer.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/devices.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
VirtualPlacer::VirtualPlacer(Cluster* cluster) : has_gpu_(false) {
|
||||
devices_["CPU"] = GetLocalCPUInfo();
|
||||
if (GetNumAvailableGPUs() > 0) {
|
||||
has_gpu_ = true;
|
||||
devices_["GPU"] = GetLocalGPUInfo(0);
|
||||
}
|
||||
unknown_device_.set_type("UNKNOWN");
|
||||
}
|
||||
|
||||
const OpInfo::DeviceProperties& VirtualPlacer::get_device(
|
||||
const NodeDef& node) const {
|
||||
string device_type;
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
if (!node.device().empty() &&
|
||||
DeviceNameUtils::ParseFullName(node.device(), &parsed)) {
|
||||
device_type = parsed.type;
|
||||
} else {
|
||||
if (has_gpu_) {
|
||||
device_type = "GPU";
|
||||
} else {
|
||||
device_type = "CPU";
|
||||
}
|
||||
}
|
||||
auto it = devices_.find(device_type);
|
||||
if (it == devices_.end()) {
|
||||
return unknown_device_;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
45
tensorflow/core/grappler/costs/virtual_placer.h
Normal file
45
tensorflow/core/grappler/costs/virtual_placer.h
Normal file
@ -0,0 +1,45 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class NodeDef;
|
||||
|
||||
namespace grappler {
|
||||
class Cluster;
|
||||
|
||||
// The virtual placer emulates the behavior of the TF placer.
|
||||
class VirtualPlacer {
|
||||
public:
|
||||
VirtualPlacer(Cluster* cluster);
|
||||
|
||||
const OpInfo::DeviceProperties& get_device(const NodeDef& node) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<string, OpInfo::DeviceProperties> devices_;
|
||||
bool has_gpu_;
|
||||
OpInfo::DeviceProperties unknown_device_;
|
||||
};
|
||||
|
||||
} // namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
|
@ -19,6 +19,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/kernels/crop_and_resize_op.h"
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
@ -26,10 +29,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
@ -37,41 +43,67 @@ namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
using Callback = std::function<void()>;
|
||||
|
||||
static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
|
||||
const Tensor& boxes,
|
||||
const Tensor& box_ind,
|
||||
namespace {
|
||||
|
||||
static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
|
||||
const Tensor& box_index,
|
||||
int* num_boxes) {
|
||||
if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
|
||||
if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
|
||||
*num_boxes = 0;
|
||||
return;
|
||||
return Status::OK();
|
||||
}
|
||||
// The shape of 'boxes' is [num_boxes, 4].
|
||||
OP_REQUIRES(context, boxes.dims() == 2,
|
||||
errors::InvalidArgument("boxes must be 2-D",
|
||||
boxes.shape().DebugString()));
|
||||
if (boxes.dims() != 2) {
|
||||
return errors::InvalidArgument("boxes must be 2-D",
|
||||
boxes.shape().DebugString());
|
||||
}
|
||||
*num_boxes = boxes.dim_size(0);
|
||||
OP_REQUIRES(context, boxes.dim_size(1) == 4,
|
||||
errors::InvalidArgument("boxes must have 4 columns"));
|
||||
|
||||
// The shape of 'box_ind' is [num_boxes].
|
||||
OP_REQUIRES(context, box_ind.dims() == 1,
|
||||
errors::InvalidArgument("box_ind must be 1-D",
|
||||
box_ind.shape().DebugString()));
|
||||
OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes,
|
||||
errors::InvalidArgument("box_ind has incompatible shape"));
|
||||
if (boxes.dim_size(1) != 4) {
|
||||
return errors::InvalidArgument("boxes must have 4 columns");
|
||||
}
|
||||
// The shape of 'box_index' is [num_boxes].
|
||||
if (box_index.dims() != 1) {
|
||||
return errors::InvalidArgument("box_index must be 1-D",
|
||||
box_index.shape().DebugString());
|
||||
}
|
||||
if (box_index.dim_size(0) != *num_boxes) {
|
||||
return errors::InvalidArgument("box_index has incompatible shape");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Verifies that all values in box_ind are in [0, batch).
|
||||
// Conditionally calls the compute callback if all values in box_index are in
|
||||
// [0, batch_size) then calls done.
|
||||
template <typename Device>
|
||||
inline void CheckValidBoxInd(
|
||||
OpKernelContext* context,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch);
|
||||
inline void RunIfBoxIndexIsValid(
|
||||
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
|
||||
int batch_size, Callback compute, Callback done);
|
||||
|
||||
// Specialization of CheckValidBoxIndex for a CPUDevice.
|
||||
template <>
|
||||
inline void RunIfBoxIndexIsValid<CPUDevice>(
|
||||
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
|
||||
int batch_size, Callback compute, Callback done) {
|
||||
const int num_boxes = box_index.dimension(0);
|
||||
for (int b = 0; b < num_boxes; ++b) {
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, FastBoundsCheck(box_index(b), batch_size),
|
||||
errors::OutOfRange("box_index has values outside [0, batch_size)"),
|
||||
done);
|
||||
}
|
||||
compute();
|
||||
done();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Device, typename T>
|
||||
class CropAndResizeOp : public OpKernel {
|
||||
class CropAndResizeOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
explicit CropAndResizeOp(OpKernelConstruction* context)
|
||||
: AsyncOpKernel(context) {
|
||||
string method;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
|
||||
OP_REQUIRES(context, method == "bilinear",
|
||||
@ -80,69 +112,77 @@ class CropAndResizeOp : public OpKernel {
|
||||
&extrapolation_value_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// The shape of 'image' is [batch, image_height, image_width, channels].
|
||||
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
|
||||
// The shape of 'image' is [batch_size, image_height, image_width,
|
||||
// channels].
|
||||
const Tensor& image = context->input(0);
|
||||
OP_REQUIRES(context, image.dims() == 4,
|
||||
errors::InvalidArgument("input image must be 4-D",
|
||||
image.shape().DebugString()));
|
||||
|
||||
const int batch = image.dim_size(0);
|
||||
const int image_height = image.dim_size(1);
|
||||
const int image_width = image.dim_size(2);
|
||||
const int depth = image.dim_size(3);
|
||||
OP_REQUIRES(context, image_height > 0 && image_width > 0,
|
||||
errors::InvalidArgument("image dimensions must be positive"));
|
||||
|
||||
// The shape of 'boxes' is [num_boxes, 4].
|
||||
const Tensor& boxes = context->input(1);
|
||||
|
||||
// The shape of 'box_ind' is [num_boxes].
|
||||
const Tensor& box_ind = context->input(2);
|
||||
|
||||
int num_boxes = 0;
|
||||
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
|
||||
|
||||
// The shape of 'box_index' is [num_boxes].
|
||||
const Tensor& box_index = context->input(2);
|
||||
// The shape of 'crop_size' is [2].
|
||||
const Tensor& crop_size = context->input(3);
|
||||
|
||||
OP_REQUIRES(context, crop_size.dims() == 1,
|
||||
errors::InvalidArgument("crop_size must be 1-D",
|
||||
crop_size.shape().DebugString()));
|
||||
OP_REQUIRES(context, crop_size.dim_size(0) == 2,
|
||||
errors::InvalidArgument("crop_size must have two elements",
|
||||
crop_size.shape().DebugString()));
|
||||
// Validate inputs dimensions.
|
||||
OP_REQUIRES_ASYNC(context, image.dims() == 4,
|
||||
errors::InvalidArgument("input image must be 4-D",
|
||||
image.shape().DebugString()),
|
||||
done);
|
||||
const int batch_size = image.dim_size(0);
|
||||
const int image_height = image.dim_size(1);
|
||||
const int image_width = image.dim_size(2);
|
||||
const int depth = image.dim_size(3);
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, image_height > 0 && image_width > 0,
|
||||
errors::InvalidArgument("image dimensions must be positive"), done);
|
||||
int num_boxes = 0;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
|
||||
|
||||
OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
|
||||
errors::InvalidArgument("crop_size must be 1-D",
|
||||
crop_size.shape().DebugString()),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, crop_size.dim_size(0) == 2,
|
||||
errors::InvalidArgument("crop_size must have two elements",
|
||||
crop_size.shape().DebugString()),
|
||||
done);
|
||||
|
||||
// Copy and validate crop sizes.
|
||||
auto crop_size_vec = crop_size.vec<int32>();
|
||||
const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
|
||||
const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
|
||||
OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
|
||||
errors::InvalidArgument("crop dimensions must be positive"));
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, crop_height > 0 && crop_width > 0,
|
||||
errors::InvalidArgument("crop dimensions must be positive"), done);
|
||||
|
||||
// Allocate output tensor.
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context,
|
||||
context->allocate_output(
|
||||
0, TensorShape({num_boxes, crop_height, crop_width, depth}),
|
||||
&output));
|
||||
&output),
|
||||
done);
|
||||
|
||||
typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
|
||||
typename TTypes<float, 2>::ConstTensor boxes_data =
|
||||
boxes.tensor<float, 2>();
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind_data =
|
||||
box_ind.tensor<int32, 1>();
|
||||
typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>();
|
||||
|
||||
CheckValidBoxInd<Device>(context, box_ind_data, batch);
|
||||
|
||||
bool status = functor::CropAndResize<Device, T>()(
|
||||
context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
|
||||
extrapolation_value_, crops_data);
|
||||
auto compute_callback = [this, context, output]() {
|
||||
const Tensor& image = context->input(0);
|
||||
const Tensor& boxes = context->input(1);
|
||||
const Tensor& box_index = context->input(2);
|
||||
const bool status = functor::CropAndResize<Device, T>()(
|
||||
context->eigen_device<Device>(), image.tensor<T, 4>(),
|
||||
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
|
||||
extrapolation_value_, output->tensor<float, 4>());
|
||||
if (!status) {
|
||||
context->SetStatus(
|
||||
errors::Internal("Failed launch CropAndResizeKernel."));
|
||||
}
|
||||
};
|
||||
|
||||
RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
|
||||
batch_size, std::move(compute_callback),
|
||||
std::move(done));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -155,10 +195,10 @@ template <typename T>
|
||||
struct CropAndResize<CPUDevice, T> {
|
||||
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
typename TTypes<int32, 1>::ConstTensor box_index,
|
||||
float extrapolation_value,
|
||||
typename TTypes<float, 4>::Tensor crops) {
|
||||
const int batch = image.dimension(0);
|
||||
const int batch_size = image.dimension(0);
|
||||
const int image_height = image.dimension(1);
|
||||
const int image_width = image.dimension(2);
|
||||
|
||||
@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> {
|
||||
const float y2 = boxes(b, 2);
|
||||
const float x2 = boxes(b, 3);
|
||||
|
||||
const int32 b_in = box_ind(b);
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
const int32 b_in = box_index(b);
|
||||
if (!FastBoundsCheck(b_in, batch_size)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
template <typename Device, typename T>
|
||||
class CropAndResizeGradImageOp : public OpKernel {
|
||||
class CropAndResizeGradImageOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
: AsyncOpKernel(context) {
|
||||
string method;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
|
||||
OP_REQUIRES(context, method == "bilinear",
|
||||
errors::InvalidArgument("method must be 'bilinear'", method));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
|
||||
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
|
||||
const Tensor& grads = context->input(0);
|
||||
|
||||
OP_REQUIRES(context, grads.dims() == 4,
|
||||
errors::InvalidArgument("grads image must be 4-D",
|
||||
grads.shape().DebugString()));
|
||||
const int crop_height = grads.dim_size(1);
|
||||
const int crop_width = grads.dim_size(2);
|
||||
OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
|
||||
errors::InvalidArgument("grads dimensions must be positive"));
|
||||
|
||||
// The shape of 'boxes' is [num_boxes, 4].
|
||||
const Tensor& boxes = context->input(1);
|
||||
|
||||
// The shape of 'box_ind' is [num_boxes].
|
||||
const Tensor& box_ind = context->input(2);
|
||||
|
||||
int num_boxes = 0;
|
||||
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
|
||||
|
||||
OP_REQUIRES(
|
||||
context, grads.dim_size(0) == num_boxes,
|
||||
errors::InvalidArgument("boxes and grads have incompatible shape"));
|
||||
|
||||
// The shape of 'box_index' is [num_boxes].
|
||||
const Tensor& box_index = context->input(2);
|
||||
// The shape of 'image_size' is [4].
|
||||
const Tensor& image_size = context->input(3);
|
||||
OP_REQUIRES(context, image_size.dims() == 1,
|
||||
errors::InvalidArgument("image_size must be 1-D",
|
||||
image_size.shape().DebugString()));
|
||||
OP_REQUIRES(context, image_size.dim_size(0) == 4,
|
||||
errors::InvalidArgument("image_size must have 4 elements",
|
||||
image_size.shape().DebugString()));
|
||||
|
||||
// Validate input shapes.
|
||||
OP_REQUIRES_ASYNC(context, grads.dims() == 4,
|
||||
errors::InvalidArgument("grads image must be 4-D",
|
||||
grads.shape().DebugString()),
|
||||
done);
|
||||
const int crop_height = grads.dim_size(1);
|
||||
const int crop_width = grads.dim_size(2);
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, crop_height > 0 && crop_width > 0,
|
||||
errors::InvalidArgument("grads dimensions must be positive"), done);
|
||||
int num_boxes = 0;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, grads.dim_size(0) == num_boxes,
|
||||
errors::InvalidArgument("boxes and grads have incompatible shape"),
|
||||
done);
|
||||
|
||||
OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
|
||||
errors::InvalidArgument("image_size must be 1-D",
|
||||
image_size.shape().DebugString()),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
|
||||
errors::InvalidArgument("image_size must have 4 elements",
|
||||
image_size.shape().DebugString()),
|
||||
done);
|
||||
auto image_size_vec = image_size.vec<int32>();
|
||||
const int batch = internal::SubtleMustCopy(image_size_vec(0));
|
||||
const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
|
||||
const int image_height = internal::SubtleMustCopy(image_size_vec(1));
|
||||
const int image_width = internal::SubtleMustCopy(image_size_vec(2));
|
||||
const int depth = internal::SubtleMustCopy(image_size_vec(3));
|
||||
|
||||
OP_REQUIRES(context, image_height > 0 && image_width > 0,
|
||||
errors::InvalidArgument("image dimensions must be positive"));
|
||||
OP_REQUIRES(
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, image_height > 0 && image_width > 0,
|
||||
errors::InvalidArgument("image dimensions must be positive"), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, grads.dim_size(3) == depth,
|
||||
errors::InvalidArgument("image_size and grads are incompatible"));
|
||||
errors::InvalidArgument("image_size and grads are incompatible"), done);
|
||||
|
||||
// Allocate output tensor.
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(
|
||||
0, TensorShape({batch, image_height, image_width, depth}),
|
||||
&output));
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context,
|
||||
context->allocate_output(
|
||||
0, TensorShape({batch_size, image_height, image_width, depth}),
|
||||
&output),
|
||||
done);
|
||||
|
||||
typename TTypes<float, 4>::ConstTensor grads_data =
|
||||
grads.tensor<float, 4>();
|
||||
typename TTypes<float, 2>::ConstTensor boxes_data =
|
||||
boxes.tensor<float, 2>();
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind_data =
|
||||
box_ind.tensor<int32, 1>();
|
||||
typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
|
||||
|
||||
CheckValidBoxInd<Device>(context, box_ind_data, batch);
|
||||
|
||||
bool status = functor::CropAndResizeBackpropImage<Device, T>()(
|
||||
context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
|
||||
output_data);
|
||||
auto compute_callback = [context, output]() {
|
||||
const Tensor& grads = context->input(0);
|
||||
const Tensor& boxes = context->input(1);
|
||||
const Tensor& box_index = context->input(2);
|
||||
const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
|
||||
context->eigen_device<Device>(), grads.tensor<float, 4>(),
|
||||
boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
|
||||
output->tensor<T, 4>());
|
||||
if (!status) {
|
||||
context->SetStatus(
|
||||
errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
|
||||
context->SetStatus(errors::Internal(
|
||||
"Failed launch CropAndResizeBackpropImage kernel."));
|
||||
}
|
||||
};
|
||||
|
||||
RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
|
||||
batch_size, std::move(compute_callback),
|
||||
std::move(done));
|
||||
}
|
||||
};
|
||||
|
||||
@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
|
||||
bool operator()(const CPUDevice& d,
|
||||
typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
typename TTypes<int32, 1>::ConstTensor box_index,
|
||||
typename TTypes<T, 4>::Tensor grads_image) {
|
||||
const int batch = grads_image.dimension(0);
|
||||
const int batch_size = grads_image.dimension(0);
|
||||
const int image_height = grads_image.dimension(1);
|
||||
const int image_width = grads_image.dimension(2);
|
||||
|
||||
@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
|
||||
const float y2 = boxes(b, 2);
|
||||
const float x2 = boxes(b, 3);
|
||||
|
||||
const int32 b_in = box_ind(b);
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
const int32 b_in = box_index(b);
|
||||
if (!FastBoundsCheck(b_in, batch_size)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
template <typename Device, typename T>
|
||||
class CropAndResizeGradBoxesOp : public OpKernel {
|
||||
class CropAndResizeGradBoxesOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
: AsyncOpKernel(context) {
|
||||
string method;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
|
||||
OP_REQUIRES(context, method == "bilinear",
|
||||
errors::InvalidArgument("method must be 'bilinear'", method));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
|
||||
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
|
||||
const Tensor& grads = context->input(0);
|
||||
// The shape of 'boxes' is [num_boxes, 4].
|
||||
const Tensor& boxes = context->input(2);
|
||||
// The shape of 'box_index' is [num_boxes].
|
||||
const Tensor& box_index = context->input(3);
|
||||
// The shape of 'image' is [batch_size, image_height, image_width, depth].
|
||||
const Tensor& image = context->input(1);
|
||||
|
||||
OP_REQUIRES(context, grads.dims() == 4,
|
||||
// Validate input shapes.
|
||||
OP_REQUIRES_ASYNC(context, grads.dims() == 4,
|
||||
errors::InvalidArgument("grads image must be 4-D",
|
||||
grads.shape().DebugString()));
|
||||
|
||||
grads.shape().DebugString()),
|
||||
done);
|
||||
const int crop_height = grads.dim_size(1);
|
||||
const int crop_width = grads.dim_size(2);
|
||||
const int depth = grads.dim_size(3);
|
||||
OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
|
||||
errors::InvalidArgument("grads dimensions must be positive"));
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, crop_height > 0 && crop_width > 0,
|
||||
errors::InvalidArgument("grads dimensions must be positive"), done);
|
||||
|
||||
// The shape of 'image' is [batch, image_height, image_width, depth].
|
||||
const Tensor& image = context->input(1);
|
||||
OP_REQUIRES(context, image.dims() == 4,
|
||||
OP_REQUIRES_ASYNC(context, image.dims() == 4,
|
||||
errors::InvalidArgument("input image must be 4-D",
|
||||
image.shape().DebugString()));
|
||||
|
||||
const int batch = image.dim_size(0);
|
||||
image.shape().DebugString()),
|
||||
done);
|
||||
const int batch_size = image.dim_size(0);
|
||||
const int image_height = image.dim_size(1);
|
||||
const int image_width = image.dim_size(2);
|
||||
OP_REQUIRES(context, image_height > 0 && image_width > 0,
|
||||
errors::InvalidArgument("image dimensions must be positive"));
|
||||
OP_REQUIRES(context, image.dim_size(3) == depth,
|
||||
errors::InvalidArgument("image, grads depth differ"));
|
||||
|
||||
// The shape of 'boxes' is [num_boxes, 4].
|
||||
const Tensor& boxes = context->input(2);
|
||||
|
||||
// The shape of 'box_ind' is [num_boxes].
|
||||
const Tensor& box_ind = context->input(3);
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, image_height > 0 && image_width > 0,
|
||||
errors::InvalidArgument("image dimensions must be positive"), done);
|
||||
OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
|
||||
errors::InvalidArgument("image, grads depth differ"),
|
||||
done);
|
||||
|
||||
int num_boxes = 0;
|
||||
ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
|
||||
|
||||
OP_REQUIRES(
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, grads.dim_size(0) == num_boxes,
|
||||
errors::InvalidArgument("boxes and grads have incompatible shape"));
|
||||
errors::InvalidArgument("boxes and grads have incompatible shape"),
|
||||
done);
|
||||
|
||||
// Allocate output tensor.
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
0, TensorShape({num_boxes, 4}), &output));
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context,
|
||||
context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
|
||||
done);
|
||||
|
||||
typename TTypes<float, 4>::ConstTensor grads_data =
|
||||
grads.tensor<float, 4>();
|
||||
typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
|
||||
typename TTypes<float, 2>::ConstTensor boxes_data =
|
||||
boxes.tensor<float, 2>();
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind_data =
|
||||
box_ind.tensor<int32, 1>();
|
||||
typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>();
|
||||
|
||||
CheckValidBoxInd<Device>(context, box_ind_data, batch);
|
||||
|
||||
bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
|
||||
context->eigen_device<Device>(), grads_data, image_data, boxes_data,
|
||||
box_ind_data, output_data);
|
||||
auto compute_callback = [context, output]() {
|
||||
const Tensor& grads = context->input(0);
|
||||
const Tensor& image = context->input(1);
|
||||
const Tensor& boxes = context->input(2);
|
||||
const Tensor& box_index = context->input(3);
|
||||
const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
|
||||
context->eigen_device<Device>(), grads.tensor<float, 4>(),
|
||||
image.tensor<T, 4>(), boxes.tensor<float, 2>(),
|
||||
box_index.tensor<int32, 1>(), output->tensor<float, 2>());
|
||||
if (!status) {
|
||||
context->SetStatus(
|
||||
errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
|
||||
context->SetStatus(errors::Internal(
|
||||
"Failed launch CropAndResizeBackpropBoxes kernel."));
|
||||
}
|
||||
};
|
||||
|
||||
RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
|
||||
batch_size, std::move(compute_callback),
|
||||
std::move(done));
|
||||
}
|
||||
};
|
||||
|
||||
@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
|
||||
typename TTypes<float, 4>::ConstTensor grads,
|
||||
typename TTypes<T, 4>::ConstTensor image,
|
||||
typename TTypes<float, 2>::ConstTensor boxes,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
typename TTypes<int32, 1>::ConstTensor box_index,
|
||||
typename TTypes<float, 2>::Tensor grads_boxes) {
|
||||
const int batch = image.dimension(0);
|
||||
const int batch_size = image.dimension(0);
|
||||
const int image_height = image.dimension(1);
|
||||
const int image_width = image.dimension(2);
|
||||
|
||||
@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
|
||||
const float y2 = boxes(b, 2);
|
||||
const float x2 = boxes(b, 3);
|
||||
|
||||
const int32 b_in = box_ind(b);
|
||||
if (b_in < 0 || b_in >= batch) {
|
||||
const int32 b_in = box_index(b);
|
||||
if (!FastBoundsCheck(b_in, batch_size)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -589,19 +641,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
|
||||
// Specialization of CheckValidBoxInd for a CPUDevice.
|
||||
template <>
|
||||
inline void CheckValidBoxInd<CPUDevice>(
|
||||
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
int batch) {
|
||||
const int num_boxes = box_ind.dimension(0);
|
||||
for (int b = 0; b < num_boxes; ++b) {
|
||||
OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch,
|
||||
errors::OutOfRange("box_ind has values outside [0, batch)"));
|
||||
}
|
||||
}
|
||||
} // namespace functor
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
|
||||
@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
// Forward declaration of the CheckValidBoxIndHelper specialization for GPU.
|
||||
// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
|
||||
namespace functor {
|
||||
template <>
|
||||
void CheckValidBoxIndHelper<GPUDevice>::operator()(
|
||||
const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
int batch, typename TTypes<bool, 0>::Tensor isvalid);
|
||||
extern template struct CheckValidBoxIndHelper<GPUDevice>;
|
||||
void CheckValidBoxIndexHelper<GPUDevice>::operator()(
|
||||
const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
|
||||
int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
|
||||
extern template struct CheckValidBoxIndexHelper<GPUDevice>;
|
||||
} // namespace functor
|
||||
|
||||
// Specialization of CheckValidBoxInd for a GPUDevice.
|
||||
namespace {
|
||||
|
||||
// Specialization of CheckValidBoxIndex for a GPUDevice.
|
||||
template <>
|
||||
inline void CheckValidBoxInd<GPUDevice>(
|
||||
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
|
||||
int batch) {
|
||||
const int num_boxes = box_ind.dimension(0);
|
||||
inline void RunIfBoxIndexIsValid<GPUDevice>(
|
||||
OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
|
||||
int batch_size, Callback compute, Callback done) {
|
||||
const int num_boxes = box_index.dimension(0);
|
||||
if (num_boxes == 0) {
|
||||
compute();
|
||||
done();
|
||||
return;
|
||||
}
|
||||
Tensor isvalid_tensor;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(DataTypeToEnum<bool>::value,
|
||||
TensorShape({}), &isvalid_tensor));
|
||||
|
||||
typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>();
|
||||
Tensor isvalid_dev_tensor;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context,
|
||||
context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
|
||||
&isvalid_dev_tensor),
|
||||
done);
|
||||
typename TTypes<bool, 0>::Tensor isvalid_dev =
|
||||
isvalid_dev_tensor.tensor<bool, 0>();
|
||||
|
||||
functor::CheckValidBoxIndHelper<GPUDevice>()(
|
||||
context->eigen_device<GPUDevice>(), box_ind, batch, isvalid);
|
||||
// Run the actual box check on the device.
|
||||
functor::CheckValidBoxIndexHelper<GPUDevice>()(
|
||||
context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
|
||||
|
||||
// Copy the result back to the host.
|
||||
auto* stream = context->op_device_context()->stream();
|
||||
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
|
||||
|
||||
bool isvalid_host = false;
|
||||
perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(),
|
||||
OP_REQUIRES_ASYNC(context, stream,
|
||||
errors::Internal("No GPU stream available."), done);
|
||||
Tensor isvalid_host_tensor;
|
||||
// Use pinned host memory on the host to avoid unnecessary
|
||||
// synchronization.
|
||||
AllocatorAttributes alloc_attr;
|
||||
alloc_attr.set_on_host(true);
|
||||
alloc_attr.set_gpu_compatible(true);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
context,
|
||||
context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
|
||||
&isvalid_host_tensor, alloc_attr),
|
||||
done);
|
||||
perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
|
||||
sizeof(bool));
|
||||
stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool));
|
||||
stream->BlockHostUntilDone();
|
||||
const bool status =
|
||||
stream
|
||||
->ThenMemcpy(
|
||||
isvalid_host_tensor.scalar<bool>().data() /* destination */,
|
||||
wrapped /* source */, sizeof(bool))
|
||||
.ok();
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, status,
|
||||
errors::Internal("Failed to launch copy of isvalid from device to host."),
|
||||
done);
|
||||
|
||||
OP_REQUIRES(context, stream->ok(),
|
||||
errors::Internal("cudaMemcpy from device to host failed"));
|
||||
auto wrapped_callback = [context, isvalid_host_tensor, compute, done]() {
|
||||
const bool isvalid = isvalid_host_tensor.scalar<bool>()();
|
||||
OP_REQUIRES_ASYNC(
|
||||
context, isvalid,
|
||||
errors::OutOfRange("box_index has values outside [0, batch_size)"),
|
||||
done);
|
||||
compute();
|
||||
done();
|
||||
};
|
||||
|
||||
OP_REQUIRES(context, isvalid_host,
|
||||
errors::OutOfRange("box_ind has values outside [0, batch)"));
|
||||
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
|
||||
stream, wrapped_callback);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
|
||||
.Device(DEVICE_GPU) \
|
||||
|
@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes {
|
||||
};
|
||||
|
||||
template <typename Device>
|
||||
struct CheckValidBoxIndHelper {
|
||||
// Checks if all values in box_ind are in [0, batch).
|
||||
struct CheckValidBoxIndexHelper {
|
||||
// Checks if all values in box_index are in [0, batch).
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<int32, 1>::ConstTensor box_ind, int batch,
|
||||
typename TTypes<int32, 1>::ConstTensor box_index, int batch,
|
||||
typename TTypes<bool, 0>::Tensor isvalid) {
|
||||
isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all();
|
||||
isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
|
||||
|
||||
#undef DEFINE_GPU_SPECS
|
||||
|
||||
template struct CheckValidBoxIndHelper<GPUDevice>;
|
||||
template struct CheckValidBoxIndexHelper<GPUDevice>;
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
|
||||
Status s = RunOpKernel();
|
||||
ASSERT_FALSE(s.ok());
|
||||
EXPECT_TRUE(
|
||||
StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
|
||||
StringPiece(s.ToString()).contains("box_index has incompatible shape"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
@ -264,8 +264,10 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
|
||||
Status s = RunOpKernel();
|
||||
ASSERT_FALSE(s.ok());
|
||||
EXPECT_TRUE(StringPiece(s.ToString())
|
||||
.contains("box_ind has values outside [0, batch)"))
|
||||
.contains("box_index has values outside [0, batch_size)"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
// TODO(zhengxq, rmlarsen): Add a benchmark.
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -155,7 +155,8 @@ void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
|
||||
const int col_dimension = input_rank - 1;
|
||||
const int64 num_rows = in.dim_size(row_dimension);
|
||||
const int64 num_cols = in.dim_size(col_dimension);
|
||||
input_matrix_shapes->emplace_back(std::initializer_list<int64>({num_rows, num_cols}));
|
||||
input_matrix_shapes->emplace_back(
|
||||
std::initializer_list<int64>({num_rows, num_cols}));
|
||||
inputs->emplace_back(&in);
|
||||
}
|
||||
// Have the derived class validate that the inputs are as expected.
|
||||
@ -233,8 +234,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
|
||||
matrix_inputs.emplace_back(
|
||||
inputs[i]->flat<Scalar>().data() +
|
||||
matrix_index * input_matrix_shapes[i].num_elements(),
|
||||
input_matrix_shapes[i].dim_size(0),
|
||||
input_matrix_shapes[i].dim_size(1));
|
||||
input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
|
||||
}
|
||||
|
||||
MatrixMaps matrix_outputs;
|
||||
|
@ -1716,6 +1716,31 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Atan2"
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "z"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "AudioSpectrogram"
|
||||
input_arg {
|
||||
|
@ -1904,6 +1904,33 @@ op {
|
||||
}
|
||||
summary: "Computes atan of x element-wise."
|
||||
}
|
||||
op {
|
||||
name: "Atan2"
|
||||
input_arg {
|
||||
name: "y"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "x"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "z"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Computes arctangent of `y/x` element-wise, respecting signs of the arguments."
|
||||
description: "This is the angle \\( \\theta \\in [-\\pi, \\pi] \\) such that\n\\[ x = r \\cos(\\theta) \\]\nand\n\\[ y = r \\sin(\\theta) \\]\nwhere \\(r = \\sqrt(x^2 + y^2) \\)."
|
||||
}
|
||||
op {
|
||||
name: "AudioSpectrogram"
|
||||
input_arg {
|
||||
|
@ -1,33 +0,0 @@
|
||||
# Random variable transformations (contrib)
|
||||
[TOC]
|
||||
|
||||
Bijector Ops.
|
||||
|
||||
An API for invertible, differentiable transformations of random variables.
|
||||
|
||||
## Background
|
||||
|
||||
Differentiable, bijective transformations of continuous random variables alter
|
||||
the calculations made in the cumulative/probability distribution functions and
|
||||
sample function. This module provides a standard interface for making these
|
||||
manipulations.
|
||||
|
||||
For more details and examples, see the `Bijector` docstring.
|
||||
|
||||
To apply a `Bijector`, use `distributions.TransformedDistribution`.
|
||||
|
||||
## Bijectors
|
||||
|
||||
* @{tf.contrib.distributions.bijector.Affine}
|
||||
* @{tf.contrib.distributions.bijector.AffineLinearOperator}
|
||||
* @{tf.contrib.distributions.bijector.Bijector}
|
||||
* @{tf.contrib.distributions.bijector.Chain}
|
||||
* @{tf.contrib.distributions.bijector.CholeskyOuterProduct}
|
||||
* @{tf.contrib.distributions.bijector.Exp}
|
||||
* @{tf.contrib.distributions.bijector.Identity}
|
||||
* @{tf.contrib.distributions.bijector.Inline}
|
||||
* @{tf.contrib.distributions.bijector.Invert}
|
||||
* @{tf.contrib.distributions.bijector.PowerTransform}
|
||||
* @{tf.contrib.distributions.bijector.SigmoidCentered}
|
||||
* @{tf.contrib.distributions.bijector.SoftmaxCentered}
|
||||
* @{tf.contrib.distributions.bijector.Softplus}
|
@ -0,0 +1,33 @@
|
||||
# Random variable transformations (contrib)
|
||||
[TOC]
|
||||
|
||||
Bijector Ops.
|
||||
|
||||
An API for invertible, differentiable transformations of random variables.
|
||||
|
||||
## Background
|
||||
|
||||
Differentiable, bijective transformations of continuous random variables alter
|
||||
the calculations made in the cumulative/probability distribution functions and
|
||||
sample function. This module provides a standard interface for making these
|
||||
manipulations.
|
||||
|
||||
For more details and examples, see the `Bijector` docstring.
|
||||
|
||||
To apply a `Bijector`, use `distributions.TransformedDistribution`.
|
||||
|
||||
## Bijectors
|
||||
|
||||
* @{tf.contrib.distributions.bijectors.Affine}
|
||||
* @{tf.contrib.distributions.bijectors.AffineLinearOperator}
|
||||
* @{tf.contrib.distributions.bijectors.Bijector}
|
||||
* @{tf.contrib.distributions.bijectors.Chain}
|
||||
* @{tf.contrib.distributions.bijectors.CholeskyOuterProduct}
|
||||
* @{tf.contrib.distributions.bijectors.Exp}
|
||||
* @{tf.contrib.distributions.bijectors.Identity}
|
||||
* @{tf.contrib.distributions.bijectors.Inline}
|
||||
* @{tf.contrib.distributions.bijectors.Invert}
|
||||
* @{tf.contrib.distributions.bijectors.PowerTransform}
|
||||
* @{tf.contrib.distributions.bijectors.SigmoidCentered}
|
||||
* @{tf.contrib.distributions.bijectors.SoftmaxCentered}
|
||||
* @{tf.contrib.distributions.bijectors.Softplus}
|
@ -76,7 +76,7 @@ representing the posterior or posterior predictive.
|
||||
|
||||
## Kullback-Leibler Divergence
|
||||
|
||||
* @{tf.contrib.distributions.kl}
|
||||
* @{tf.contrib.distributions.kl_divergence}
|
||||
* @{tf.contrib.distributions.RegisterKL}
|
||||
|
||||
## Utilities
|
||||
|
@ -40,7 +40,7 @@
|
||||
* [Losses (contrib)](contrib.losses.md)
|
||||
* [Metrics (contrib)](contrib.metrics.md)
|
||||
* [Optimization (contrib)](contrib.opt.md)
|
||||
* [Random variable transformations (contrib)](contrib.distributions.bijector.md)
|
||||
* [Random variable transformations (contrib)](contrib.distributions.bijectors.md)
|
||||
* [RNN and Cells (contrib)](contrib.rnn.md)
|
||||
* [Seq2seq Library (contrib)](contrib.seq2seq.md)
|
||||
* [Staging (contrib)](contrib.staging.md)
|
||||
|
@ -80,10 +80,12 @@ section.
|
||||
* **OS:** Ubuntu 16.04 LTS with tests run via Docker
|
||||
* **CUDA / cuDNN:** 8.0 / 5.1
|
||||
* **TensorFlow GitHub hash:** b1e174e
|
||||
* **Benchmark GitHub hash:** 9165a70
|
||||
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
|
||||
//tensorflow/tools/pip_package:build_pip_package`
|
||||
* **Disk:** Local SSD
|
||||
* **DataSet:** ImageNet
|
||||
* **Test Date:** May 2017
|
||||
|
||||
Batch size and optimizer used for each model are listed in the table below. In
|
||||
addition to the batch sizes listed in the table, InceptionV3, ResNet-50,
|
||||
@ -120,19 +122,19 @@ VGG16 | replicated (with NCCL) | n/a
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
|
||||
---- | ----------- | --------- | ---------- | ------- | -----
|
||||
1 | 142 | 238 | 95.6 | 2987 | 154
|
||||
2 | 284 | 479 | 187 | 5658 | 295
|
||||
4 | 569 | 948 | 374 | 10509 | 584
|
||||
8 | 1131 | 1886 | 744 | 17822 | 1081
|
||||
1 | 142 | 219 | 91.8 | 2987 | 154
|
||||
2 | 284 | 422 | 181 | 5658 | 295
|
||||
4 | 569 | 852 | 356 | 10509 | 584
|
||||
8 | 1131 | 1734 | 716 | 17822 | 1081
|
||||
|
||||
**Training real data**
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
|
||||
---- | ----------- | --------- | ---------- | ------- | -----
|
||||
1 | 142 | 239 | 95.5 | 2890 | 154
|
||||
2 | 278 | 468 | 187 | 4448 | 284
|
||||
4 | 551 | 938 | 373 | 7105 | 534
|
||||
8 | 1079 | 1802 | 721 | N/A | 898
|
||||
1 | 142 | 218 | 91.4 | 2890 | 154
|
||||
2 | 278 | 425 | 179 | 4448 | 284
|
||||
4 | 551 | 853 | 359 | 7105 | 534
|
||||
8 | 1079 | 1630 | 708 | N/A | 898
|
||||
|
||||
Training AlexNet with real data on 8 GPUs was excluded from the graph and table
|
||||
above due to it maxing out the input pipeline.
|
||||
@ -145,19 +147,19 @@ The results below are all with a batch size of 32.
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
|
||||
---- | ----------- | --------- | ---------- | -----
|
||||
1 | 128 | 210 | 85.3 | 144
|
||||
2 | 259 | 412 | 166 | 281
|
||||
4 | 520 | 827 | 330 | 549
|
||||
8 | 995 | 1623 | 643 | 820
|
||||
1 | 128 | 195 | 82.7 | 144
|
||||
2 | 259 | 368 | 160 | 281
|
||||
4 | 520 | 768 | 317 | 549
|
||||
8 | 995 | 1485 | 632 | 820
|
||||
|
||||
**Training real data**
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
|
||||
---- | ----------- | --------- | ---------- | -----
|
||||
1 | 130 | 208 | 85.0 | 144
|
||||
2 | 257 | 403 | 163 | 253
|
||||
4 | 507 | 814 | 325 | 457
|
||||
8 | 966 | 1525 | 641 | 690
|
||||
1 | 130 | 193 | 82.4 | 144
|
||||
2 | 257 | 369 | 159 | 253
|
||||
4 | 507 | 760 | 317 | 457
|
||||
8 | 966 | 1410 | 609 | 690
|
||||
|
||||
## Details for Google Compute Engine (NVIDIA® Tesla® K80)
|
||||
|
||||
@ -168,11 +170,12 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
|
||||
* **OS:** Ubuntu 16.04 LTS
|
||||
* **CUDA / cuDNN:** 8.0 / 5.1
|
||||
* **TensorFlow GitHub hash:** b1e174e
|
||||
* **Benchmark GitHub hash:** 9165a70
|
||||
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
|
||||
//tensorflow/tools/pip_package:build_pip_package`
|
||||
* **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s)
|
||||
* **DataSet:** ImageNet
|
||||
* **Test Date:** April 2017
|
||||
* **Test Date:** May 2017
|
||||
|
||||
Batch size and optimizer used for each model are listed in the table below. In
|
||||
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
|
||||
@ -198,19 +201,19 @@ The configuration used for each model was `variable_update` equal to
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
|
||||
---- | ----------- | --------- | ---------- | ------- | -----
|
||||
1 | 30.5 | 56.8 | 20.8 | 656 | 35.4
|
||||
2 | 57.8 | 107 | 39.1 | 1209 | 64.8
|
||||
4 | 116 | 212 | 77.2 | 2328 | 120
|
||||
8 | 227 | 419 | 151 | 4640 | 234
|
||||
1 | 30.5 | 51.9 | 20.0 | 656 | 35.4
|
||||
2 | 57.8 | 99.0 | 38.2 | 1209 | 64.8
|
||||
4 | 116 | 195 | 75.8 | 2328 | 120
|
||||
8 | 227 | 387 | 148 | 4640 | 234
|
||||
|
||||
**Training real data**
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
|
||||
---- | ----------- | --------- | ---------- | ------- | -----
|
||||
1 | 30.6 | 56.7 | 20.7 | 639 | 34.2
|
||||
2 | 58.4 | 107 | 39.0 | 1136 | 62.9
|
||||
4 | 115 | 211 | 77.3 | 2067 | 118
|
||||
8 | 225 | 422 | 151 | 4056 | 230
|
||||
1 | 30.6 | 51.2 | 20.0 | 639 | 34.2
|
||||
2 | 58.4 | 98.8 | 38.3 | 1136 | 62.9
|
||||
4 | 115 | 194 | 75.4 | 2067 | 118
|
||||
8 | 225 | 381 | 148 | 4056 | 230
|
||||
|
||||
### Other Results
|
||||
|
||||
@ -218,19 +221,19 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
|
||||
|
||||
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
|
||||
---- | --------------------------- | -------------------------
|
||||
1 | 29.3 | 53.9
|
||||
2 | 55.0 | 101
|
||||
4 | 109 | 200
|
||||
8 | 216 | 398
|
||||
1 | 29.3 | 49.5
|
||||
2 | 55.0 | 95.4
|
||||
4 | 109 | 183
|
||||
8 | 216 | 362
|
||||
|
||||
**Training real data**
|
||||
|
||||
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
|
||||
---- | --------------------------- | -------------------------
|
||||
1 | 29.5 | 53.6
|
||||
2 | 55.4 | 102
|
||||
4 | 110 | 201
|
||||
8 | 216 | 387
|
||||
1 | 29.5 | 49.3
|
||||
2 | 55.4 | 95.3
|
||||
4 | 110 | 186
|
||||
8 | 216 | 359
|
||||
|
||||
## Details for Amazon EC2 (NVIDIA® Tesla® K80)
|
||||
|
||||
@ -241,12 +244,13 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
|
||||
* **OS:** Ubuntu 16.04 LTS
|
||||
* **CUDA / cuDNN:** 8.0 / 5.1
|
||||
* **TensorFlow GitHub hash:** b1e174e
|
||||
* **Benchmark GitHub hash:** 9165a70
|
||||
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
|
||||
//tensorflow/tools/pip_package:build_pip_package`
|
||||
* **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50
|
||||
MiB/sec)
|
||||
* **DataSet:** ImageNet
|
||||
* **Test Date:** April 2017
|
||||
* **Test Date:** May 2017
|
||||
|
||||
Batch size and optimizer used for each model are listed in the table below. In
|
||||
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
|
||||
@ -279,19 +283,19 @@ VGG16 | parameter_server | gpu
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
|
||||
---- | ----------- | --------- | ---------- | ------- | -----
|
||||
1 | 30.8 | 56.3 | 20.9 | 684 | 36.3
|
||||
2 | 58.7 | 108 | 39.3 | 1244 | 69.4
|
||||
4 | 117 | 217 | 79.1 | 2479 | 141
|
||||
8 | 230 | 419 | 156 | 4853 | 260
|
||||
1 | 30.8 | 51.5 | 19.7 | 684 | 36.3
|
||||
2 | 58.7 | 98.0 | 37.6 | 1244 | 69.4
|
||||
4 | 117 | 195 | 74.9 | 2479 | 141
|
||||
8 | 230 | 384 | 149 | 4853 | 260
|
||||
|
||||
**Training real data**
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
|
||||
---- | ----------- | --------- | ---------- | ------- | -----
|
||||
1 | 30.5 | 56.0 | 20.6 | 674 | 36.3
|
||||
2 | 59.0 | 107 | 39.0 | 1227 | 67.5
|
||||
4 | 118 | 205 | 77.9 | 2201 | 136
|
||||
8 | 228 | 405 | 152 | N/A | 242
|
||||
1 | 30.5 | 51.3 | 19.7 | 674 | 36.3
|
||||
2 | 59.0 | 94.9 | 38.2 | 1227 | 67.5
|
||||
4 | 118 | 188 | 75.2 | 2201 | 136
|
||||
8 | 228 | 373 | 149 | N/A | 242
|
||||
|
||||
Training AlexNet with real data on 8 GPUs was excluded from the graph and table
|
||||
above due to our EFS setup not providing enough throughput.
|
||||
@ -302,19 +306,19 @@ above due to our EFS setup not providing enough throughput.
|
||||
|
||||
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
|
||||
---- | --------------------------- | -------------------------
|
||||
1 | 29.9 | 53.5
|
||||
2 | 57.5 | 101
|
||||
4 | 114 | 202
|
||||
8 | 216 | 380
|
||||
1 | 29.9 | 49.0
|
||||
2 | 57.5 | 94.1
|
||||
4 | 114 | 184
|
||||
8 | 216 | 355
|
||||
|
||||
**Training real data**
|
||||
|
||||
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
|
||||
---- | --------------------------- | -------------------------
|
||||
1 | 30.0 | 53.6
|
||||
2 | 57.5 | 102
|
||||
4 | 113 | 202
|
||||
8 | 212 | 379
|
||||
1 | 30.0 | 49.1
|
||||
2 | 57.5 | 95.1
|
||||
4 | 113 | 185
|
||||
8 | 212 | 353
|
||||
|
||||
## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80)
|
||||
|
||||
@ -325,11 +329,12 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
|
||||
* **OS:** Ubuntu 16.04 LTS
|
||||
* **CUDA / cuDNN:** 8.0 / 5.1
|
||||
* **TensorFlow GitHub hash:** b1e174e
|
||||
* **Benchmark GitHub hash:** 9165a70
|
||||
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
|
||||
//tensorflow/tools/pip_package:build_pip_package`
|
||||
* **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec)
|
||||
* **DataSet:** ImageNet
|
||||
* **Test Date:** April 2017
|
||||
* **Test Date:** May 2017
|
||||
|
||||
The batch size and optimizer used for the tests are listed in the table. In
|
||||
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
|
||||
@ -343,11 +348,11 @@ Optimizer | sgd | sgd | sgd
|
||||
|
||||
Configuration used for each model.
|
||||
|
||||
Model | variable_update | local_parameter_device
|
||||
----------- | ---------------------- | ----------------------
|
||||
InceptionV3 | distributed_replicated | n/a
|
||||
ResNet-50 | distributed_replicated | n/a
|
||||
ResNet-152 | distributed_replicated | n/a
|
||||
Model | variable_update | local_parameter_device | cross_replica_sync
|
||||
----------- | ---------------------- | ---------------------- | ------------------
|
||||
InceptionV3 | distributed_replicated | n/a | True
|
||||
ResNet-50 | distributed_replicated | n/a | True
|
||||
ResNet-152 | distributed_replicated | n/a | True
|
||||
|
||||
To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also
|
||||
ran parameter servers. Equal numbers of parameter servers and work servers were
|
||||
@ -371,11 +376,11 @@ used with the following exceptions:
|
||||
|
||||
GPUs | InceptionV3 | ResNet-50 | ResNet-152
|
||||
---- | ----------- | --------- | ----------
|
||||
1 | 29.7 | 55.0 | 19.8
|
||||
8 | 229 | 410 | 150
|
||||
16 | 459 | 825 | 300
|
||||
32 | 902 | 1468 | 575
|
||||
64 | 1783 | 3051 | 1004
|
||||
1 | 29.7 | 52.4 | 19.4
|
||||
8 | 229 | 378 | 146
|
||||
16 | 459 | 751 | 291
|
||||
32 | 902 | 1388 | 565
|
||||
64 | 1783 | 2744 | 981
|
||||
|
||||
### Other Results
|
||||
|
||||
@ -387,16 +392,16 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152
|
||||
|
||||
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
|
||||
---- | --------------------------- | -------------------------
|
||||
1 | 29.2 | 53.0
|
||||
8 | 219 | 363
|
||||
16 | 427 | 719
|
||||
32 | 820 | 1265
|
||||
64 | 1608 | 2623
|
||||
|
||||
1 | 29.2 | 48.4
|
||||
8 | 219 | 333
|
||||
16 | 427 | 667
|
||||
32 | 820 | 1180
|
||||
64 | 1608 | 2315
|
||||
|
||||
## Methodology
|
||||
|
||||
This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
|
||||
This
|
||||
[script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
|
||||
was run on the various platforms to generate the above results.
|
||||
@{$performance_models$High-Performance Models} details techniques in the script
|
||||
along with examples of how to execute the script.
|
||||
|
@ -370,9 +370,8 @@ def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
|
||||
tf.logging.fatal('File does not exist %s', image_path)
|
||||
image_data = gfile.FastGFile(image_path, 'rb').read()
|
||||
try:
|
||||
bottleneck_values = run_bottleneck_on_image(sess, image_data,
|
||||
jpeg_data_tensor,
|
||||
bottleneck_tensor)
|
||||
bottleneck_values = run_bottleneck_on_image(
|
||||
sess, image_data, jpeg_data_tensor, bottleneck_tensor)
|
||||
except:
|
||||
raise RuntimeError('Error during processing file %s' % image_path)
|
||||
|
||||
|
@ -5583,6 +5583,74 @@ func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Store the input tensor in the state of the current session.
|
||||
//
|
||||
// Arguments:
|
||||
// value: The tensor to be stored.
|
||||
//
|
||||
// Returns The handle for the tensor stored in the session state, represented
|
||||
// as a ResourceHandle object.
|
||||
func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "GetSessionHandleV2",
|
||||
Input: []tf.Input{
|
||||
value,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Adjust the hue of one or more images.
|
||||
//
|
||||
// `images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
// interpretted as channels, and must be three.
|
||||
//
|
||||
// The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
// colors are first mapped into HSV. A delta is then applied all the hue values,
|
||||
// and then remapped back to RGB colorspace.
|
||||
//
|
||||
// Arguments:
|
||||
// images: Images to adjust. At least 3-D.
|
||||
// delta: A float delta to add to the hue.
|
||||
//
|
||||
// Returns The hue-adjusted image or images.
|
||||
func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "AdjustHue",
|
||||
Input: []tf.Input{
|
||||
images, delta,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Restore a Reader to its initial clean state.
|
||||
//
|
||||
// Arguments:
|
||||
// reader_handle: Handle to a Reader.
|
||||
//
|
||||
// Returns the created operation.
|
||||
func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "ReaderResetV2",
|
||||
Input: []tf.Input{
|
||||
reader_handle,
|
||||
},
|
||||
}
|
||||
return scope.AddOperation(opspec)
|
||||
}
|
||||
|
||||
// Computes softmax cross entropy cost and gradients to backpropagate.
|
||||
//
|
||||
// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
|
||||
@ -19039,6 +19107,27 @@ func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Computes arctangent of `y/x` element-wise, respecting signs of the arguments.
|
||||
//
|
||||
// This is the angle \( \theta \in [-\pi, \pi] \) such that
|
||||
// \[ x = r \cos(\theta) \]
|
||||
// and
|
||||
// \[ y = r \sin(\theta) \]
|
||||
// where \(r = \sqrt(x^2 + y^2) \).
|
||||
func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "Atan2",
|
||||
Input: []tf.Input{
|
||||
y, x,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
|
||||
//
|
||||
// The regularized incomplete beta integral is defined as:
|
||||
@ -21627,71 +21716,3 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0), op.Output(1)
|
||||
}
|
||||
|
||||
// Store the input tensor in the state of the current session.
|
||||
//
|
||||
// Arguments:
|
||||
// value: The tensor to be stored.
|
||||
//
|
||||
// Returns The handle for the tensor stored in the session state, represented
|
||||
// as a ResourceHandle object.
|
||||
func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "GetSessionHandleV2",
|
||||
Input: []tf.Input{
|
||||
value,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Adjust the hue of one or more images.
|
||||
//
|
||||
// `images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
// interpretted as channels, and must be three.
|
||||
//
|
||||
// The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
// colors are first mapped into HSV. A delta is then applied all the hue values,
|
||||
// and then remapped back to RGB colorspace.
|
||||
//
|
||||
// Arguments:
|
||||
// images: Images to adjust. At least 3-D.
|
||||
// delta: A float delta to add to the hue.
|
||||
//
|
||||
// Returns The hue-adjusted image or images.
|
||||
func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "AdjustHue",
|
||||
Input: []tf.Input{
|
||||
images, delta,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Restore a Reader to its initial clean state.
|
||||
//
|
||||
// Arguments:
|
||||
// reader_handle: Handle to a Reader.
|
||||
//
|
||||
// Returns the created operation.
|
||||
func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "ReaderResetV2",
|
||||
Input: []tf.Input{
|
||||
reader_handle,
|
||||
},
|
||||
}
|
||||
return scope.AddOperation(opspec)
|
||||
}
|
||||
|
@ -1 +0,0 @@
|
||||
#include "unsupported/Eigen/CXX11/ThreadPool"
|
@ -129,6 +129,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -656,6 +657,44 @@ def categorical_column_with_vocabulary_list(
|
||||
default_value=default_value)
|
||||
|
||||
|
||||
def categorical_column_with_identity(key, num_buckets, default_value=None):
|
||||
"""A `_CategoricalColumn` that returns identity values.
|
||||
|
||||
Use this when your inputs are integers in the range `[0, num_buckets)`. Values
|
||||
outside this range will result in `default_value` if specified, otherwise it
|
||||
will fail.
|
||||
|
||||
Inputs can be either `Tensor` or `SparseTensor`.
|
||||
```
|
||||
|
||||
Args:
|
||||
key: A unique string identifying the input feature. It is used as the
|
||||
column name and the dictionary key for feature parsing configs, feature
|
||||
`Tensor` objects, and feature columns.
|
||||
num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
|
||||
default_value: If `None`, this column's graph operations will fail for
|
||||
out-of-range inputs. Otherwise, this value must be in the range
|
||||
`[0, num_buckets)`, and will replace inputs in that range.
|
||||
|
||||
Returns:
|
||||
A `_CategoricalColumn` that returns identity values.
|
||||
|
||||
Raises:
|
||||
ValueError: if `num_buckets` is less than one.
|
||||
ValueError: if `default_value` is not in range `[0, num_buckets)`.
|
||||
"""
|
||||
if num_buckets < 1:
|
||||
raise ValueError(
|
||||
'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
|
||||
if (default_value is not None) and (
|
||||
(default_value < 0) or (default_value >= num_buckets)):
|
||||
raise ValueError(
|
||||
'default_value {} not in range [0, {}), column_name {}'.format(
|
||||
default_value, num_buckets, key))
|
||||
return _IdentityCategoricalColumn(
|
||||
key=key, num_buckets=num_buckets, default_value=default_value)
|
||||
|
||||
|
||||
class _FeatureColumn(object):
|
||||
"""Represents a feature column abstraction.
|
||||
|
||||
@ -1384,6 +1423,69 @@ class _VocabularyListCategoricalColumn(
|
||||
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
||||
|
||||
|
||||
class _IdentityCategoricalColumn(
|
||||
_CategoricalColumn,
|
||||
collections.namedtuple('_IdentityCategoricalColumn', (
|
||||
'key', 'num_buckets', 'default_value'
|
||||
))):
|
||||
|
||||
"""See `categorical_column_with_identity`."""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.key
|
||||
|
||||
@property
|
||||
def _parse_example_config(self):
|
||||
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
|
||||
|
||||
def _transform_feature(self, inputs):
|
||||
input_tensor = _to_sparse_input(inputs.get(self.key))
|
||||
|
||||
if not input_tensor.dtype.is_integer:
|
||||
raise ValueError(
|
||||
'Invalid input, not integer. key: {} dtype: {}'.format(
|
||||
self.key, input_tensor.dtype))
|
||||
|
||||
values = math_ops.to_int64(input_tensor.values, name='values')
|
||||
num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets')
|
||||
zero = math_ops.to_int64(0, name='zero')
|
||||
if self.default_value is None:
|
||||
# Fail if values are out-of-range.
|
||||
assert_less = check_ops.assert_less(
|
||||
values, num_buckets, data=(values, num_buckets),
|
||||
name='assert_less_than_num_buckets')
|
||||
assert_greater = check_ops.assert_greater_equal(
|
||||
values, zero, data=(values,),
|
||||
name='assert_greater_or_equal_0')
|
||||
with ops.control_dependencies((assert_less, assert_greater)):
|
||||
values = array_ops.identity(values)
|
||||
else:
|
||||
# Assign default for out-of-range values.
|
||||
values = array_ops.where(
|
||||
math_ops.logical_or(
|
||||
values < zero, values >= num_buckets, name='out_of_range'),
|
||||
array_ops.fill(
|
||||
dims=array_ops.shape(values),
|
||||
value=math_ops.to_int64(self.default_value),
|
||||
name='default_values'),
|
||||
values)
|
||||
|
||||
return sparse_tensor_lib.SparseTensor(
|
||||
indices=input_tensor.indices,
|
||||
values=values,
|
||||
dense_shape=input_tensor.dense_shape)
|
||||
|
||||
@property
|
||||
def _num_buckets(self):
|
||||
"""Returns number of buckets in this sparse feature."""
|
||||
return self.num_buckets
|
||||
|
||||
def _get_sparse_tensors(
|
||||
self, inputs, weight_collections=None, trainable=None):
|
||||
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
|
||||
|
||||
|
||||
# TODO(zakaria): Move this to embedding_ops and make it public.
|
||||
def _safe_embedding_lookup_sparse(embedding_weights,
|
||||
sparse_ids,
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
@ -1828,5 +1829,198 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
||||
self.assertAllClose(((3.,), (1.,)), predictions.eval())
|
||||
|
||||
|
||||
class IdentityCategoricalColumnTest(test.TestCase):
|
||||
|
||||
def test_constructor(self):
|
||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
self.assertEqual('aaa', column.name)
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(3, column._num_buckets)
|
||||
self.assertEqual({
|
||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||
}, column._parse_example_config)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def test_deep_copy(self):
|
||||
"""Tests deepcopy of categorical_column_with_hash_bucket."""
|
||||
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
for column in (original, copy.deepcopy(original)):
|
||||
self.assertEqual('aaa', column.name)
|
||||
# pylint: disable=protected-access
|
||||
self.assertEqual(3, column._num_buckets)
|
||||
self.assertEqual({
|
||||
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||
}, column._parse_example_config)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def test_invalid_num_buckets_zero(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
|
||||
fc.categorical_column_with_identity(key='aaa', num_buckets=0)
|
||||
|
||||
def test_invalid_num_buckets_negative(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
|
||||
fc.categorical_column_with_identity(key='aaa', num_buckets=-1)
|
||||
|
||||
def test_invalid_default_value_too_small(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
|
||||
fc.categorical_column_with_identity(
|
||||
key='aaa', num_buckets=3, default_value=-1)
|
||||
|
||||
def test_invalid_default_value_too_big(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
|
||||
fc.categorical_column_with_identity(
|
||||
key='aaa', num_buckets=3, default_value=3)
|
||||
|
||||
def test_invalid_input_dtype(self):
|
||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
inputs = sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (1, 1)),
|
||||
values=('omar', 'stringer', 'marlo'),
|
||||
dense_shape=(2, 2))
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
|
||||
# pylint: disable=protected-access
|
||||
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def test_get_sparse_tensors(self):
|
||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
inputs = sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (1, 1)),
|
||||
values=(0, 1, 0),
|
||||
dense_shape=(2, 2))
|
||||
# pylint: disable=protected-access
|
||||
id_weight_pair = column._get_sparse_tensors(
|
||||
fc._LazyBuilder({'aaa': inputs}))
|
||||
# pylint: enable=protected-access
|
||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||
with _initialized_session():
|
||||
_assert_sparse_tensor_value(
|
||||
self,
|
||||
sparse_tensor.SparseTensorValue(
|
||||
indices=inputs.indices,
|
||||
values=np.array((0, 1, 0), dtype=np.int64),
|
||||
dense_shape=inputs.dense_shape),
|
||||
id_weight_pair.id_tensor.eval())
|
||||
|
||||
def test_get_sparse_tensors_dense_input(self):
|
||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
# pylint: disable=protected-access
|
||||
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
|
||||
'aaa': ((0, -1), (1, 0))
|
||||
}))
|
||||
# pylint: enable=protected-access
|
||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||
with _initialized_session():
|
||||
_assert_sparse_tensor_value(
|
||||
self,
|
||||
sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (1, 1)),
|
||||
values=np.array((0, 1, 0), dtype=np.int64),
|
||||
dense_shape=(2, 2)),
|
||||
id_weight_pair.id_tensor.eval())
|
||||
|
||||
def test_get_sparse_tensors_with_inputs_too_small(self):
|
||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
inputs = sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (1, 1)),
|
||||
values=(1, -1, 0),
|
||||
dense_shape=(2, 2))
|
||||
# pylint: disable=protected-access
|
||||
id_weight_pair = column._get_sparse_tensors(
|
||||
fc._LazyBuilder({'aaa': inputs}))
|
||||
# pylint: enable=protected-access
|
||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||
with _initialized_session():
|
||||
with self.assertRaisesRegexp(
|
||||
errors.OpError, 'assert_greater_or_equal_0'):
|
||||
id_weight_pair.id_tensor.eval()
|
||||
|
||||
def test_get_sparse_tensors_with_inputs_too_big(self):
|
||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
inputs = sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (1, 1)),
|
||||
values=(1, 99, 0),
|
||||
dense_shape=(2, 2))
|
||||
# pylint: disable=protected-access
|
||||
id_weight_pair = column._get_sparse_tensors(
|
||||
fc._LazyBuilder({'aaa': inputs}))
|
||||
# pylint: enable=protected-access
|
||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||
with _initialized_session():
|
||||
with self.assertRaisesRegexp(
|
||||
errors.OpError, 'assert_less_than_num_buckets'):
|
||||
id_weight_pair.id_tensor.eval()
|
||||
|
||||
def test_get_sparse_tensors_with_default_value(self):
|
||||
column = fc.categorical_column_with_identity(
|
||||
key='aaa', num_buckets=4, default_value=3)
|
||||
inputs = sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (1, 1)),
|
||||
values=(1, -1, 99),
|
||||
dense_shape=(2, 2))
|
||||
# pylint: disable=protected-access
|
||||
id_weight_pair = column._get_sparse_tensors(
|
||||
fc._LazyBuilder({'aaa': inputs}))
|
||||
# pylint: enable=protected-access
|
||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||
with _initialized_session():
|
||||
_assert_sparse_tensor_value(
|
||||
self,
|
||||
sparse_tensor.SparseTensorValue(
|
||||
indices=inputs.indices,
|
||||
values=np.array((1, 3, 3), dtype=np.int64),
|
||||
dense_shape=inputs.dense_shape),
|
||||
id_weight_pair.id_tensor.eval())
|
||||
|
||||
def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
|
||||
column = fc.categorical_column_with_identity(
|
||||
key='aaa', num_buckets=4, default_value=3)
|
||||
input_indices = array_ops.placeholder(dtype=dtypes.int64)
|
||||
input_values = array_ops.placeholder(dtype=dtypes.int32)
|
||||
input_shape = array_ops.placeholder(dtype=dtypes.int64)
|
||||
inputs = sparse_tensor.SparseTensorValue(
|
||||
indices=input_indices,
|
||||
values=input_values,
|
||||
dense_shape=input_shape)
|
||||
# pylint: disable=protected-access
|
||||
id_weight_pair = column._get_sparse_tensors(
|
||||
fc._LazyBuilder({'aaa': inputs}))
|
||||
# pylint: enable=protected-access
|
||||
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||
with _initialized_session():
|
||||
_assert_sparse_tensor_value(
|
||||
self,
|
||||
sparse_tensor.SparseTensorValue(
|
||||
indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
|
||||
values=np.array((1, 3, 3), dtype=np.int64),
|
||||
dense_shape=np.array((2, 2), dtype=np.int64)),
|
||||
id_weight_pair.id_tensor.eval(feed_dict={
|
||||
input_indices: ((0, 0), (1, 0), (1, 1)),
|
||||
input_values: (1, -1, 99),
|
||||
input_shape: (2, 2),
|
||||
}))
|
||||
|
||||
def test_make_linear_model(self):
|
||||
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||
self.assertEqual(3, column._num_buckets)
|
||||
with ops.Graph().as_default():
|
||||
predictions = fc.make_linear_model({
|
||||
column.name: sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (1, 1)),
|
||||
values=(0, 2, 1),
|
||||
dense_shape=(2, 2))
|
||||
}, (column,))
|
||||
bias = get_linear_model_bias()
|
||||
weight_var = get_linear_model_column_var(column)
|
||||
with _initialized_session():
|
||||
self.assertAllClose((0.,), bias.eval())
|
||||
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
|
||||
self.assertAllClose(((0.,), (0.,)), predictions.eval())
|
||||
weight_var.assign(((1.,), (2.,), (3.,))).eval()
|
||||
# weight_var[0] = 1
|
||||
# weight_var[2] + weight_var[1] = 3+2 = 5
|
||||
self.assertAllClose(((1.,), (5.,)), predictions.eval())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -113,18 +113,19 @@ def _add_op_node(op, func, input_dict):
|
||||
node_def = func.node_def[-1]
|
||||
for i in range(len(node_def.input)):
|
||||
if not node_def.input[i].startswith("^"):
|
||||
assert node_def.input[i] in input_dict, (
|
||||
"%s missing from %s" % (node_def.input[i], input_dict.items()))
|
||||
assert node_def.input[i] in input_dict, ("%s missing from %s" %
|
||||
(node_def.input[i],
|
||||
input_dict.items()))
|
||||
node_def.input[i] = input_dict[node_def.input[i]]
|
||||
|
||||
|
||||
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
||||
def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
|
||||
"""Returns `graph` as a `FunctionDef` protocol buffer.
|
||||
|
||||
This method creates a [`FunctionDef`](
|
||||
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
|
||||
protocol buffer that contains all the ops present in the graph. The
|
||||
graph effectively becomes the body of the function.
|
||||
protocol buffer that contains all the ops in `operations`. The
|
||||
operations become the body of the function.
|
||||
|
||||
The arguments `inputs` and `outputs` will be listed as the inputs
|
||||
and outputs tensors of the function. They must be lists of
|
||||
@ -132,6 +133,8 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
||||
|
||||
Args:
|
||||
graph: Graph.
|
||||
operations: the operations to put in the function. Must be a subset of
|
||||
the operations in the graph.
|
||||
inputs: List of tensors. Inputs to the function.
|
||||
outputs: List of tensors. Outputs of the function.
|
||||
out_names: Optional list of string names for the outputs.
|
||||
@ -145,12 +148,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
||||
func = function_pb2.FunctionDef()
|
||||
func.signature.name = "_"
|
||||
used_names = set()
|
||||
func.signature.input_arg.extend([_tensor_to_argdef(i, used_names=used_names)
|
||||
for i in inputs])
|
||||
func.signature.input_arg.extend(
|
||||
[_tensor_to_argdef(i, used_names=used_names) for i in inputs])
|
||||
if out_names is None:
|
||||
used_names = set()
|
||||
func.signature.output_arg.extend([
|
||||
_tensor_to_argdef(o, used_names=used_names) for o in outputs])
|
||||
func.signature.output_arg.extend(
|
||||
[_tensor_to_argdef(o, used_names=used_names) for o in outputs])
|
||||
elif len(outputs) != len(out_names):
|
||||
raise ValueError(
|
||||
"Length of out_names (%d) does not match number of outputs (%d): %s" %
|
||||
@ -159,12 +162,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
||||
raise ValueError(
|
||||
"Must not have duplicates in out_names: %s" % ", ".join(out_names))
|
||||
else:
|
||||
func.signature.output_arg.extend([
|
||||
_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
|
||||
func.signature.output_arg.extend(
|
||||
[_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
|
||||
func_arg_placeholders = set([i.name for i in inputs])
|
||||
input_dict = _create_input_dict(graph, func_arg_placeholders)
|
||||
|
||||
for op in graph.get_operations():
|
||||
for op in operations:
|
||||
if _is_in_placeholders(op, func_arg_placeholders):
|
||||
continue
|
||||
_add_op_node(op, func, input_dict)
|
||||
@ -295,7 +298,8 @@ class _FuncGraph(ops.Graph):
|
||||
self.extra_args = []
|
||||
self.extra_vars = []
|
||||
|
||||
def getvar(self,
|
||||
def getvar(
|
||||
self,
|
||||
getter,
|
||||
name,
|
||||
shape=None,
|
||||
@ -538,20 +542,23 @@ class _DefinedFunction(object):
|
||||
|
||||
# Build the FunctionDef
|
||||
self._definition = _graph_to_function_def(
|
||||
temp_graph, inputs, outputs, out_names=self._out_names)
|
||||
temp_graph,
|
||||
temp_graph.get_operations(),
|
||||
inputs,
|
||||
outputs,
|
||||
out_names=self._out_names)
|
||||
|
||||
# Extra kwargs are treated as attrs on the function def.
|
||||
sig_pre_func_name = self._func_name or _get_func_name(self._func)
|
||||
kwargs_attr = _parse_kwargs_as_attrs(
|
||||
sig_pre_func_name, **self._extra_kwargs)
|
||||
kwargs_attr = _parse_kwargs_as_attrs(sig_pre_func_name,
|
||||
**self._extra_kwargs)
|
||||
for k in kwargs_attr:
|
||||
self._definition.attr[k].CopyFrom(kwargs_attr[k])
|
||||
|
||||
# Hash the definition and its dependencies.
|
||||
self._hash_str = self._create_hash_str(
|
||||
self._definition.signature.input_arg,
|
||||
self._definition.signature.output_arg,
|
||||
self._definition.node_def)
|
||||
self._definition.signature.output_arg, self._definition.node_def)
|
||||
|
||||
# Finally, we decide the function name to use. If not specified,
|
||||
# make up something which is almost certainly unique (but deterministic).
|
||||
@ -658,8 +665,8 @@ def _from_definition(fdef, grad_func=None):
|
||||
# have access to such a callable here).
|
||||
func = None
|
||||
argnames = [arg.name for arg in fdef.signature.input_arg]
|
||||
input_types = tuple(dtypes.as_dtype(arg.type)
|
||||
for arg in fdef.signature.input_arg)
|
||||
input_types = tuple(
|
||||
dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
|
||||
func_name = fdef.signature.name
|
||||
# Note: FunctionDefs do not include python gradient functions, so if the
|
||||
# original _DefinedFunction included one it will not be reflected here.
|
||||
@ -675,8 +682,7 @@ def _from_definition(fdef, grad_func=None):
|
||||
result._extra_inputs = []
|
||||
result._hash_str = result._create_hash_str(
|
||||
result._definition.signature.input_arg,
|
||||
result._definition.signature.output_arg,
|
||||
result._definition.node_def)
|
||||
result._definition.signature.output_arg, result._definition.node_def)
|
||||
# pylint: enable=protected-access
|
||||
return result
|
||||
|
||||
@ -696,7 +702,8 @@ def _from_library(lib):
|
||||
Raises:
|
||||
ValueError: `lib` is invalid
|
||||
"""
|
||||
if not lib.function and not lib.gradient: return []
|
||||
if not lib.function and not lib.gradient:
|
||||
return []
|
||||
|
||||
# function name -> FunctionDef proto
|
||||
funcs = {fdef.signature.name: fdef for fdef in lib.function}
|
||||
@ -720,8 +727,9 @@ def _from_library(lib):
|
||||
grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
|
||||
|
||||
# Start with functions without gradients
|
||||
ready = [fdef for fdef in lib.function
|
||||
if func_to_grad[fdef.signature.name] is None]
|
||||
ready = [
|
||||
fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
|
||||
]
|
||||
if not ready:
|
||||
raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
|
||||
+ str(lib))
|
||||
@ -733,7 +741,8 @@ def _from_library(lib):
|
||||
name = fdef.signature.name
|
||||
|
||||
grad = initialized.get(func_to_grad[name])
|
||||
if func_to_grad[name]: assert grad
|
||||
if func_to_grad[name]:
|
||||
assert grad
|
||||
defined_func = _from_definition(fdef, grad_func=grad)
|
||||
initialized[name] = defined_func
|
||||
|
||||
@ -835,8 +844,13 @@ class _OverloadedFunction(object):
|
||||
name = self._func_name
|
||||
if name is not None:
|
||||
name = "_".join([name, key])
|
||||
defined = _DefinedFunction(self._func, self._argnames, input_types, name,
|
||||
None, self._python_grad_func,
|
||||
defined = _DefinedFunction(
|
||||
self._func,
|
||||
self._argnames,
|
||||
input_types,
|
||||
name,
|
||||
None,
|
||||
self._python_grad_func,
|
||||
out_names=self._out_names,
|
||||
**self._extra_kwargs)
|
||||
_ = defined.name # Fully instantiate the function definition.
|
||||
@ -849,8 +863,8 @@ class _OverloadedFunction(object):
|
||||
for _ in defined.definition.signature.output_arg
|
||||
]
|
||||
# pylint: disable=protected-access
|
||||
defined._grad_func = self._grad_func.instantiate(input_types +
|
||||
output_types)
|
||||
defined._grad_func = self._grad_func.instantiate(
|
||||
input_types + output_types)
|
||||
# pylint: enable=protected-access
|
||||
self._overload[key] = defined
|
||||
return defined
|
||||
@ -981,22 +995,36 @@ class Defun(object):
|
||||
raise ValueError(
|
||||
"The function has fewer arguments than the number of specified "
|
||||
"input types.")
|
||||
return _DefinedFunction(func, argnames, self._input_types,
|
||||
self._func_name, self._grad_func,
|
||||
return _DefinedFunction(
|
||||
func,
|
||||
argnames,
|
||||
self._input_types,
|
||||
self._func_name,
|
||||
self._grad_func,
|
||||
self._python_grad_func,
|
||||
out_names=self._out_names, **self._extra_kwargs)
|
||||
out_names=self._out_names,
|
||||
**self._extra_kwargs)
|
||||
|
||||
# 'func' expects no arguments and input types is an empty list.
|
||||
if min_args == 0 and max_args == 0:
|
||||
return _DefinedFunction(func, [], [], self._func_name, self._grad_func,
|
||||
return _DefinedFunction(
|
||||
func, [], [],
|
||||
self._func_name,
|
||||
self._grad_func,
|
||||
self._python_grad_func,
|
||||
out_names=self._out_names, **self._extra_kwargs)
|
||||
out_names=self._out_names,
|
||||
**self._extra_kwargs)
|
||||
|
||||
# Input types are unknown. It's an overloaded function and hence
|
||||
# its definition needs to be deferred until it's called.
|
||||
return _OverloadedFunction(func, argnames, self._func_name, self._grad_func,
|
||||
return _OverloadedFunction(
|
||||
func,
|
||||
argnames,
|
||||
self._func_name,
|
||||
self._grad_func,
|
||||
self._python_grad_func,
|
||||
out_names=self._out_names, **self._extra_kwargs)
|
||||
out_names=self._out_names,
|
||||
**self._extra_kwargs)
|
||||
|
||||
|
||||
class Declare(object):
|
||||
@ -1039,8 +1067,10 @@ class Declare(object):
|
||||
names = [n for n, t in args]
|
||||
if len(names) != len(set(names)):
|
||||
raise ValueError("Expected names to all be unique: %s" % str(names))
|
||||
return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
|
||||
for n, t in args]
|
||||
return [
|
||||
op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
|
||||
for n, t in args
|
||||
]
|
||||
|
||||
self._sig.input_arg.extend(_to_argdef_list(inputs))
|
||||
self._sig.output_arg.extend(_to_argdef_list(outputs))
|
||||
|
@ -1106,16 +1106,18 @@ class BinaryOpTest(test.TestCase):
|
||||
|
||||
def testAtan2SpecialValues(self):
|
||||
x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
|
||||
(1.2345, float('inf')), (1.2345, -float('inf')),
|
||||
(-4.321, float('inf')), (-4.125, -float('inf')),
|
||||
(float('inf'), float('inf')), (float('inf'), -float('inf')),
|
||||
(-float('inf'), float('inf')), (-float('inf'), -float('inf')))
|
||||
(1.2345, float("inf")), (1.2345, -float("inf")),
|
||||
(-4.321, float("inf")), (-4.125, -float("inf")),
|
||||
(float("inf"), float("inf")), (float("inf"), -float("inf")),
|
||||
(-float("inf"), float("inf")), (-float("inf"),
|
||||
-float("inf")))
|
||||
for dtype in np.float32, np.float64:
|
||||
x1 = np.array(x1l).astype(dtype)
|
||||
x2 = np.array(x2l).astype(dtype)
|
||||
self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
|
||||
self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
|
||||
|
||||
|
||||
class ComparisonOpTest(test.TestCase):
|
||||
|
||||
def _compareScalar(self, func, x, y, dtype):
|
||||
|
@ -19,58 +19,65 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow import Tensor
|
||||
from tensorflow import register_tensor_conversion_function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test as test_lib
|
||||
|
||||
|
||||
class TensorPriorityTest(test_lib.TestCase):
|
||||
|
||||
def testSupportedRhsWithoutDelegation(self):
|
||||
|
||||
class NumpyArraySubclass(np.ndarray):
|
||||
pass
|
||||
supported_rhs_without_delegation = (
|
||||
3,
|
||||
3.0,
|
||||
[1.0, 2.0],
|
||||
np.array([1.0, 2.0]),
|
||||
NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0])),
|
||||
|
||||
supported_rhs_without_delegation = (3, 3.0, [1.0, 2.0], np.array(
|
||||
[1.0, 2.0]), NumpyArraySubclass(
|
||||
shape=(1, 2), buffer=np.array([1.0, 2.0])),
|
||||
ops.convert_to_tensor([[1.0, 2.0]]))
|
||||
for rhs in supported_rhs_without_delegation:
|
||||
tensor = ops.convert_to_tensor([[10.0, 20.0]])
|
||||
res = tensor + rhs
|
||||
self.assertIsInstance(res, Tensor)
|
||||
self.assertIsInstance(res, ops.Tensor)
|
||||
|
||||
def testUnsupportedRhsWithoutDelegation(self):
|
||||
|
||||
class WithoutReverseAdd(object):
|
||||
pass
|
||||
|
||||
tensor = ops.convert_to_tensor([[10.0, 20.0]])
|
||||
rhs = WithoutReverseAdd()
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
TypeError, lambda e: "Expected float" in str(e)):
|
||||
res = tensor + rhs
|
||||
# pylint: disable=pointless-statement
|
||||
tensor + rhs
|
||||
|
||||
def testUnsupportedRhsWithDelegation(self):
|
||||
|
||||
class WithReverseAdd(object):
|
||||
|
||||
def __radd__(self, lhs):
|
||||
return "Works!"
|
||||
|
||||
tensor = ops.convert_to_tensor([[10.0, 20.0]])
|
||||
rhs = WithReverseAdd()
|
||||
res = tensor + rhs
|
||||
self.assertEqual(res, "Works!")
|
||||
|
||||
def testFullDelegationControlUsingRegistry(self):
|
||||
|
||||
class NumpyArraySubclass(np.ndarray):
|
||||
|
||||
def __radd__(self, lhs):
|
||||
return "Works!"
|
||||
|
||||
def raise_to_delegate(value, dtype=None, name=None, as_ref=False):
|
||||
del value, dtype, name, as_ref # Unused.
|
||||
raise TypeError
|
||||
register_tensor_conversion_function(NumpyArraySubclass, raise_to_delegate,
|
||||
priority=0)
|
||||
|
||||
ops.register_tensor_conversion_function(
|
||||
NumpyArraySubclass, raise_to_delegate, priority=0)
|
||||
tensor = ops.convert_to_tensor([[10.0, 20.0]])
|
||||
rhs = NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0]))
|
||||
rhs = NumpyArraySubclass(shape=(1, 2), buffer=np.array([1.0, 2.0]))
|
||||
res = tensor + rhs
|
||||
self.assertEqual(res, "Works!")
|
||||
|
||||
|
@ -1109,10 +1109,10 @@ class Conv2DTranspose(Conv2D):
|
||||
# Infer the static output shape:
|
||||
out_shape = inputs.get_shape().as_list()
|
||||
out_shape[c_axis] = self.filters
|
||||
out_shape[h_axis] = utils.get_deconv_dim(
|
||||
out_shape[h_axis], stride_h, kernel_h, self.padding)
|
||||
out_shape[w_axis] = utils.get_deconv_dim(
|
||||
out_shape[w_axis], stride_w, kernel_w, self.padding)
|
||||
out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
|
||||
kernel_h, self.padding)
|
||||
out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
|
||||
kernel_w, self.padding)
|
||||
outputs.set_shape(out_shape)
|
||||
|
||||
if self.bias:
|
||||
@ -1240,7 +1240,8 @@ class Conv3DTranspose(Conv3D):
|
||||
name: A string, the name of the layer.
|
||||
"""
|
||||
|
||||
def __init__(self, filters,
|
||||
def __init__(self,
|
||||
filters,
|
||||
kernel_size,
|
||||
strides=(1, 1, 1),
|
||||
padding='valid',
|
||||
@ -1269,12 +1270,13 @@ class Conv3DTranspose(Conv3D):
|
||||
bias_regularizer=bias_regularizer,
|
||||
activity_regularizer=activity_regularizer,
|
||||
trainable=trainable,
|
||||
name=name, **kwargs)
|
||||
name=name,
|
||||
**kwargs)
|
||||
|
||||
def build(self, input_shape):
|
||||
if len(input_shape) != 5:
|
||||
raise ValueError('Inputs should have rank 5, ' +
|
||||
'received input shape:', str(input_shape))
|
||||
raise ValueError('Inputs should have rank 5, received input shape:',
|
||||
str(input_shape))
|
||||
if self.data_format == 'channels_first':
|
||||
channel_axis = 1
|
||||
else:
|
||||
@ -1285,14 +1287,16 @@ class Conv3DTranspose(Conv3D):
|
||||
input_dim = input_shape[channel_axis]
|
||||
kernel_shape = self.kernel_size + (self.filters, input_dim)
|
||||
|
||||
self.kernel = self.add_variable('kernel',
|
||||
self.kernel = self.add_variable(
|
||||
'kernel',
|
||||
shape=kernel_shape,
|
||||
initializer=self.kernel_initializer,
|
||||
regularizer=self.kernel_regularizer,
|
||||
trainable=True,
|
||||
dtype=self.dtype)
|
||||
if self.use_bias:
|
||||
self.bias = self.add_variable('bias',
|
||||
self.bias = self.add_variable(
|
||||
'bias',
|
||||
shape=(self.filters,),
|
||||
initializer=self.bias_initializer,
|
||||
regularizer=self.bias_regularizer,
|
||||
@ -1300,7 +1304,6 @@ class Conv3DTranspose(Conv3D):
|
||||
dtype=self.dtype)
|
||||
else:
|
||||
self.bias = None
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
inputs_shape = array_ops.shape(inputs)
|
||||
@ -1343,26 +1346,26 @@ class Conv3DTranspose(Conv3D):
|
||||
# Infer the static output shape:
|
||||
out_shape = inputs.get_shape().as_list()
|
||||
out_shape[c_axis] = self.filters
|
||||
out_shape[d_axis] = utils.get_deconv_dim(
|
||||
out_shape[d_axis], stride_d, kernel_d, self.padding)
|
||||
out_shape[h_axis] = utils.get_deconv_dim(
|
||||
out_shape[h_axis], stride_h, kernel_h, self.padding)
|
||||
out_shape[w_axis] = utils.get_deconv_dim(
|
||||
out_shape[w_axis], stride_w, kernel_w, self.padding)
|
||||
out_shape[d_axis] = utils.get_deconv_dim(out_shape[d_axis], stride_d,
|
||||
kernel_d, self.padding)
|
||||
out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
|
||||
kernel_h, self.padding)
|
||||
out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
|
||||
kernel_w, self.padding)
|
||||
outputs.set_shape(out_shape)
|
||||
|
||||
if self.bias:
|
||||
outputs_shape = outputs.shape.as_list()
|
||||
if self.data_format == 'channels_first':
|
||||
outputs_4d = array_ops.reshape(outputs,
|
||||
[outputs_shape[0], outputs_shape[1],
|
||||
outputs_shape[2] * outputs_shape[3],
|
||||
outputs_shape[4]])
|
||||
outputs_4d = array_ops.reshape(outputs, [
|
||||
outputs_shape[0], outputs_shape[1],
|
||||
outputs_shape[2] * outputs_shape[3], outputs_shape[4]
|
||||
])
|
||||
else:
|
||||
outputs_4d = array_ops.reshape(outputs,
|
||||
[outputs_shape[0],
|
||||
outputs_shape[1] * outputs_shape[2],
|
||||
outputs_shape[3], outputs_shape[4]])
|
||||
outputs_4d = array_ops.reshape(outputs, [
|
||||
outputs_shape[0], outputs_shape[1] * outputs_shape[2],
|
||||
outputs_shape[3], outputs_shape[4]
|
||||
])
|
||||
outputs_4d = nn.bias_add(
|
||||
outputs_4d,
|
||||
self.bias,
|
||||
|
@ -715,8 +715,8 @@ class Conv3DTransposeTest(test.TestCase):
|
||||
layer = conv_layers.Conv3DTranspose(
|
||||
32, volumes.get_shape()[1:4], padding='same')
|
||||
output = layer.apply(volumes)
|
||||
self.assertListEqual(output.get_shape().as_list(), [5, depth, height,
|
||||
width, 32])
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[5, depth, height, width, 32])
|
||||
|
||||
def testCreateConv3DTransposeWithStrides(self):
|
||||
depth, height, width = 4, 6, 8
|
||||
@ -729,8 +729,7 @@ class Conv3DTransposeTest(test.TestCase):
|
||||
[5, depth * 2, height * 2, width * 2, 4])
|
||||
|
||||
# Test strides integer.
|
||||
layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2,
|
||||
padding='same')
|
||||
layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2, padding='same')
|
||||
output = layer.apply(volumes)
|
||||
self.assertListEqual(output.get_shape().as_list(),
|
||||
[5, depth * 2, height * 2, width * 2, 4])
|
||||
@ -779,14 +778,14 @@ class Conv3DTransposeTest(test.TestCase):
|
||||
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
|
||||
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
|
||||
self.assertEqual(len(variables.trainable_variables()), 2)
|
||||
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
|
||||
conv_layers.conv3d_transpose(
|
||||
volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
|
||||
self.assertEqual(len(variables.trainable_variables()), 2)
|
||||
|
||||
def testFunctionalConv3DTransposeReuseFromScope(self):
|
||||
with variable_scope.variable_scope('scope'):
|
||||
depth, height, width = 5, 7, 9
|
||||
volumes = random_ops.random_uniform((5, depth, height, width, 32),
|
||||
seed=1)
|
||||
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
|
||||
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
|
||||
self.assertEqual(len(variables.trainable_variables()), 2)
|
||||
with variable_scope.variable_scope('scope', reuse=True):
|
||||
@ -798,8 +797,8 @@ class Conv3DTransposeTest(test.TestCase):
|
||||
with variable_scope.variable_scope(
|
||||
'scope', initializer=init_ops.ones_initializer()):
|
||||
depth, height, width = 5, 7, 9
|
||||
volumes = random_ops.random_uniform((5, depth, height, width, 32),
|
||||
seed=1)
|
||||
volumes = random_ops.random_uniform(
|
||||
(5, depth, height, width, 32), seed=1)
|
||||
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
|
||||
weights = variables.trainable_variables()
|
||||
# Check the names of weights in order.
|
||||
|
@ -205,7 +205,8 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
||||
`decoded.shape`: Shape vector, size `(2)`.
|
||||
The shape values are: `[batch_size, max_decoded_length]`
|
||||
neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
|
||||
sequence found, the negative of the sum of the greatest logit at each timeframe.
|
||||
sequence found, the negative of the sum of the greatest logit at each
|
||||
timeframe.
|
||||
"""
|
||||
outputs = gen_ctc_ops._ctc_greedy_decoder(
|
||||
inputs, sequence_length, merge_repeated=merge_repeated)
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
@ -964,8 +964,12 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
|
||||
ValueError: If input/output depth does not match `filters`' shape, or if
|
||||
padding is other than `'VALID'` or `'SAME'`.
|
||||
"""
|
||||
return convolution(input=value, filter=filters, padding=padding,
|
||||
dilation_rate=np.broadcast_to(rate, (2, )), name=name)
|
||||
return convolution(
|
||||
input=value,
|
||||
filter=filters,
|
||||
padding=padding,
|
||||
dilation_rate=np.broadcast_to(rate, (2,)),
|
||||
name=name)
|
||||
|
||||
|
||||
def conv2d_transpose(value,
|
||||
@ -1231,8 +1235,8 @@ def conv3d_transpose(value,
|
||||
axis = 1 if data_format == "NCDHW" else 4
|
||||
if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]):
|
||||
raise ValueError("input channels does not match filter's input channels, "
|
||||
"{} != {}".format(value.get_shape()[axis], filter.get_shape(
|
||||
)[4]))
|
||||
"{} != {}".format(value.get_shape()[axis],
|
||||
filter.get_shape()[4]))
|
||||
|
||||
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
|
||||
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)):
|
||||
|
@ -195,7 +195,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
||||
Raises:
|
||||
RuntimeError: MetaGraphDef associated with the tags cannot be found.
|
||||
"""
|
||||
# Build the SavedModel protocol buffer and find the requested meta graph def.
|
||||
with sess.graph.as_default():
|
||||
# Build the SavedModel protocol buffer and find requested meta graph def.
|
||||
saved_model = _parse_saved_model(export_dir)
|
||||
found_match = False
|
||||
for meta_graph_def in saved_model.meta_graphs:
|
||||
@ -234,7 +235,7 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
||||
else:
|
||||
legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
|
||||
if legacy_init_op_tensor is not None:
|
||||
sess.run(fetches=[legacy_init_op_tensor],
|
||||
feed_dict=asset_tensors_dictionary)
|
||||
sess.run(
|
||||
fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary)
|
||||
|
||||
return meta_graph_def_to_load
|
||||
|
@ -151,6 +151,27 @@ class SavedModelTest(test.TestCase):
|
||||
constants.SAVED_MODEL_FILENAME_PBTXT):
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
|
||||
def testVerifySessionGraphUsage(self):
|
||||
export_dir = os.path.join(test.get_temp_dir(),
|
||||
"test_verify_session_graph_usage")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
||||
with self.test_session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
# Build a session and supply it to the load operation.
|
||||
sess = session.Session(graph=ops.Graph())
|
||||
loader.load(sess, [tag_constants.TRAINING], export_dir)
|
||||
|
||||
# Check the variable within the scope of the session and its graph.
|
||||
with sess:
|
||||
self.assertEqual(
|
||||
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
|
||||
|
||||
def testSequence(self):
|
||||
export_dir = os.path.join(test.get_temp_dir(), "test_sequence")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
@ -12,12 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ================================
|
||||
"""Imports a protobuf model as a graph in Tensorboard."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.summary import summary
|
||||
|
||||
|
||||
def import_to_tensorboard(model_dir, log_dir):
|
||||
@ -32,13 +38,13 @@ def import_to_tensorboard(model_dir, log_dir):
|
||||
Launch Tensorboard by pointing it to the log directory.
|
||||
View your imported `.pb` model as a graph.
|
||||
"""
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
with tf.gfile.FastGFile(model_dir, 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
with gfile.FastGFile(model_dir, "rb") as f:
|
||||
graph_def = graph_pb2.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
g_in = tf.import_graph_def(graph_def)
|
||||
importer.import_graph_def(graph_def)
|
||||
|
||||
pb_visual_writer = tf.summary.FileWriter(log_dir)
|
||||
pb_visual_writer = summary.FileWriter(log_dir)
|
||||
pb_visual_writer.add_graph(sess.graph)
|
||||
print("Model Imported. Visualize by running: "
|
||||
"> tensorboard --logdir={}".format(log_dir))
|
||||
|
@ -504,7 +504,14 @@ def run(args):
|
||||
|
||||
Args:
|
||||
args: A namespace parsed from command line.
|
||||
|
||||
Raises:
|
||||
AttributeError: An error when neither --inputs nor --input_exprs is passed
|
||||
to run command.
|
||||
"""
|
||||
if not args.inputs and not args.input_exprs:
|
||||
raise AttributeError(
|
||||
'At least one of --inputs and --input_exprs must be required')
|
||||
tensor_key_feed_dict = load_inputs_from_input_arg_string(
|
||||
args.inputs, args.input_exprs)
|
||||
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
|
||||
@ -629,8 +636,6 @@ def create_parser():
|
||||
def main():
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
if not args.inputs and not args.input_exprs:
|
||||
args.error('At least one of --inputs and --input_exprs is required')
|
||||
args.func(args)
|
||||
|
||||
|
||||
|
@ -409,6 +409,16 @@ Method name is: tensorflow/serving/predict"""
|
||||
with self.assertRaises(RuntimeError):
|
||||
saved_model_cli.run(args)
|
||||
|
||||
def testRunCommandInputNotGivenError(self):
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
args = self.parser.parse_args([
|
||||
'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
|
||||
'serving_default'
|
||||
])
|
||||
with self.assertRaises(AttributeError):
|
||||
saved_model_cli.run(args)
|
||||
|
||||
def testRunCommandWithDebuggerEnabled(self):
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
|
||||
|
@ -210,9 +210,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
|
||||
else:
|
||||
var_name = ",".join([v.name for v in var])
|
||||
_set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
|
||||
logging.info("Initialize variable %s from checkpoint %s with %s" % (
|
||||
var_name, ckpt_dir_or_file, tensor_name_in_ckpt
|
||||
))
|
||||
logging.info("Initialize variable %s from checkpoint %s with %s",
|
||||
var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
|
||||
else:
|
||||
scopes = ""
|
||||
# TODO(vihanjain): Support list of 'current_var_or_name' here.
|
||||
@ -250,9 +249,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
|
||||
if var is None:
|
||||
var = _collect_partitioned_variable(var_name, store_vars)
|
||||
_set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
|
||||
logging.info("Initialize variable %s from checkpoint %s with %s" % (
|
||||
var_name, ckpt_dir_or_file, full_tensor_name
|
||||
))
|
||||
logging.info("Initialize variable %s from checkpoint %s with %s",
|
||||
var_name, ckpt_dir_or_file, full_tensor_name)
|
||||
|
||||
|
||||
def _get_checkpoint_filename(ckpt_dir_or_file):
|
||||
|
@ -935,11 +935,11 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
|
||||
ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
|
||||
except errors.OpError as e:
|
||||
# It's ok if the file cannot be read
|
||||
logging.warning("%s: %s" % (type(e).__name__, e))
|
||||
logging.warning("%s: %s", type(e).__name__, e)
|
||||
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
|
||||
return None
|
||||
except text_format.ParseError as e:
|
||||
logging.warning("%s: %s" % (type(e).__name__, e))
|
||||
logging.warning("%s: %s", type(e).__name__, e)
|
||||
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
|
||||
return None
|
||||
finally:
|
||||
|
@ -230,13 +230,15 @@ class TensorboardServerTest(test.TestCase):
|
||||
def testScalars(self):
|
||||
"""Test the format of /data/scalars."""
|
||||
data = self._getJson('/data/scalars?run=run1&tag=simple_values')
|
||||
self.assertEqual(len(data),self._SCALAR_COUNT)
|
||||
self.assertEqual(len(data), self._SCALAR_COUNT)
|
||||
|
||||
def testScalarsCsv(self):
|
||||
"""Test the csv format of /data/scalars."""
|
||||
data = self._get('/data/scalars?run=run1&tag=simple_values&format=csv').read()
|
||||
data = self._get(
|
||||
'/data/scalars?run=run1&tag=simple_values&format=csv').read()
|
||||
line_count = data.count('\n')
|
||||
self.assertEqual(line_count,self._SCALAR_COUNT + 1) # include 1 more line for header
|
||||
self.assertEqual(line_count,
|
||||
self._SCALAR_COUNT + 1) # include 1 more line for header
|
||||
|
||||
def testHistograms(self):
|
||||
"""Test the format of /data/histograms."""
|
||||
|
63
tensorflow/tensorboard/components/tf_audio_dashboard/BUILD
Normal file
63
tensorflow/tensorboard/components/tf_audio_dashboard/BUILD
Normal file
@ -0,0 +1,63 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
webfiles(
|
||||
name = "tf_audio_dashboard",
|
||||
srcs = [
|
||||
"tf-audio-dashboard.html",
|
||||
"tf-audio-grid.html",
|
||||
"tf-audio-loader.html",
|
||||
],
|
||||
path = "/tf-audio-dashboard",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_backend",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common",
|
||||
"//tensorflow/tensorboard/components/tf_imports:lodash",
|
||||
"@org_polymer",
|
||||
"@org_polymer_paper_icon_button",
|
||||
"@org_polymer_paper_slider",
|
||||
"@org_polymer_paper_spinner",
|
||||
"@org_polymer_paper_styles",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# MARKED FOR DELETION
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "legacy",
|
||||
srcs = [
|
||||
"tf-audio-dashboard.html",
|
||||
"tf-audio-grid.html",
|
||||
"tf-audio-loader.html",
|
||||
],
|
||||
destdir = "tf-audio-dashboard",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:tf_imports",
|
||||
"//tensorflow/tensorboard/components/tf_backend:legacy",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
|
||||
"//third_party/javascript/polymer/v1/paper-icon-button:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
# This is needed: components/BUILD seeks a legacy_ts rule in this package.
|
||||
tensorboard_ts_library(
|
||||
name = "legacy_ts",
|
||||
srcs = [],
|
||||
deps_mgmt = "off",
|
||||
runtime = "nodejs",
|
||||
deps = ["//tensorflow/tensorboard/components:common_deps"],
|
||||
)
|
@ -0,0 +1,26 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# bazel run //third_party/tensorflow/tensorboard/components/tf_audio_dashboard/demo
|
||||
webfiles(
|
||||
name = "demo",
|
||||
srcs = ["index.html"],
|
||||
path = "/tf-audio-dashboard/demo",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_audio_dashboard",
|
||||
"//tensorflow/tensorboard/components/tf_audio_dashboard/demo/data",
|
||||
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||
"@org_polymer_iron_demo_helpers",
|
||||
"@org_polymer_paper_styles",
|
||||
"@org_polymer_webcomponentsjs",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
@ -0,0 +1,17 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
webfiles(
|
||||
name = "data",
|
||||
srcs = glob(["*"]),
|
||||
path = "/tf-audio-dashboard/demo/data",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
@ -107,6 +107,8 @@ future for loading older clips.
|
||||
</template>
|
||||
</template>
|
||||
<script>
|
||||
"use strict";
|
||||
|
||||
Polymer({
|
||||
is: "tf-audio-loader",
|
||||
properties: {
|
||||
|
81
tensorflow/tensorboard/components/tf_backend/BUILD
Normal file
81
tensorflow/tensorboard/components/tf_backend/BUILD
Normal file
@ -0,0 +1,81 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# TODO(dandelion): Add webfiles support for the test code.
|
||||
|
||||
webfiles(
|
||||
name = "tf_backend",
|
||||
srcs = [
|
||||
"tf-backend.html",
|
||||
":ts",
|
||||
],
|
||||
path = "/tf-backend",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||
"//tensorflow/tensorboard/components/tf_imports:lodash",
|
||||
"//tensorflow/tensorboard/components/vz_sorting",
|
||||
"@org_polymer",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_typescript_genrule(
|
||||
name = "ts",
|
||||
srcs = [
|
||||
"backend.ts",
|
||||
"behavior.ts",
|
||||
"requestManager.ts",
|
||||
"router.ts",
|
||||
"urlPathHelpers.ts",
|
||||
],
|
||||
typings = [
|
||||
"@org_definitelytyped//:d3.d.ts",
|
||||
"@org_definitelytyped//:lodash.d.ts",
|
||||
"//tensorflow/tensorboard/components/vz_sorting:ts_typings",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# MARKED FOR DELETION
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "legacy",
|
||||
srcs = [
|
||||
"tf-backend.html",
|
||||
":legacy_ts",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
destdir = "tf-backend",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:tf_imports",
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "legacy_ts",
|
||||
srcs = [
|
||||
"backend.ts",
|
||||
"behavior.ts",
|
||||
"requestManager.ts",
|
||||
"router.ts",
|
||||
"urlPathHelpers.ts",
|
||||
],
|
||||
deps_mgmt = "off",
|
||||
runtime = "nodejs",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:common_deps",
|
||||
"//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
|
||||
],
|
||||
)
|
45
tensorflow/tensorboard/components/tf_backend_d3v4/BUILD
Normal file
45
tensorflow/tensorboard/components/tf_backend_d3v4/BUILD
Normal file
@ -0,0 +1,45 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "ts",
|
||||
srcs = [
|
||||
"backend.ts",
|
||||
"behavior.ts",
|
||||
"requestManager.ts",
|
||||
"router.ts",
|
||||
"urlPathHelpers.ts",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
|
||||
"//third_party/javascript/node_modules/typescript:es2015.promise",
|
||||
"//third_party/javascript/typings/chai",
|
||||
"//third_party/javascript/typings/d3_v4:bundle",
|
||||
"//third_party/javascript/typings/lodash",
|
||||
"//third_party/javascript/typings/mocha",
|
||||
"//third_party/javascript/typings/polymer:polymer_without_externs",
|
||||
"//third_party/javascript/typings/sinon",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(dandelion): Add runners for these tests
|
||||
tensorboard_ts_library(
|
||||
name = "tests",
|
||||
srcs = [
|
||||
"backendTests.ts",
|
||||
"behaviorTests.ts",
|
||||
"requestManagerTests.ts",
|
||||
],
|
||||
deps = [
|
||||
":ts",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
65
tensorflow/tensorboard/components/tf_color_scale/BUILD
Normal file
65
tensorflow/tensorboard/components/tf_color_scale/BUILD
Normal file
@ -0,0 +1,65 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# TODO(dandelion): Add webfiles support for the test code.
|
||||
|
||||
webfiles(
|
||||
name = "tf_color_scale",
|
||||
srcs = [
|
||||
"tf-color-scale.html",
|
||||
":ts",
|
||||
],
|
||||
path = "/tf-color-scale",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||
"@org_polymer",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_typescript_genrule(
|
||||
name = "ts",
|
||||
srcs = [
|
||||
"colorScale.ts",
|
||||
"palettes.ts",
|
||||
],
|
||||
typings = ["@org_definitelytyped//:d3.d.ts"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# MARKED FOR DELETION
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "legacy",
|
||||
srcs = [
|
||||
"tf-color-scale.html",
|
||||
":legacy_ts",
|
||||
],
|
||||
destdir = "tf-color-scale",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:tf_imports",
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "legacy_ts",
|
||||
srcs = [
|
||||
"colorScale.ts",
|
||||
"palettes.ts",
|
||||
],
|
||||
deps_mgmt = "off",
|
||||
runtime = "nodejs",
|
||||
deps = ["//tensorflow/tensorboard/components:common_deps"],
|
||||
)
|
26
tensorflow/tensorboard/components/tf_color_scale/demo/BUILD
Normal file
26
tensorflow/tensorboard/components/tf_color_scale/demo/BUILD
Normal file
@ -0,0 +1,26 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# bazel run //third_party/tensorflow/tensorboard/components/tf_color_scale/demo
|
||||
webfiles(
|
||||
name = "demo",
|
||||
srcs = ["index.html"],
|
||||
path = "/tf-color-scale/demo",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_color_scale",
|
||||
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||
"@org_polymer_iron_demo_helpers",
|
||||
"@org_polymer_paper_button",
|
||||
"@org_polymer_paper_styles",
|
||||
"@org_polymer_webcomponentsjs",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
72
tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD
Normal file
72
tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD
Normal file
@ -0,0 +1,72 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load(
|
||||
"//tensorflow/tensorboard:defs.bzl",
|
||||
"tensorboard_ts_development_sources",
|
||||
"tensorboard_ts_devserver",
|
||||
"tensorboard_ts_library",
|
||||
"tensorboard_webcomponent_library",
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# TODO(dandelion): Add runner for the test code.
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "tf_color_scale",
|
||||
srcs = ["tf-color-scale.html"],
|
||||
ts_lib_deps = [":ts"],
|
||||
deps = [
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "ts",
|
||||
srcs = [
|
||||
"colorScale.ts",
|
||||
"palettes.ts",
|
||||
],
|
||||
deps = ["//tensorflow/tensorboard/components:common_deps_d3v4"],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "tests",
|
||||
srcs = ["colorScaleTests.ts"],
|
||||
deps = [
|
||||
":ts",
|
||||
"//tensorflow/tensorboard/components:common_deps_d3v4",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "demo",
|
||||
srcs = ["demo.html"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tf_color_scale",
|
||||
"//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-button:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_devserver(
|
||||
name = "devserver",
|
||||
manifest = ":dev_sources",
|
||||
serving_path = "/demo_out/bundle.js",
|
||||
static_files = [":demo"],
|
||||
)
|
||||
|
||||
tensorboard_ts_development_sources(
|
||||
name = "dev_sources",
|
||||
deps = [
|
||||
":ts",
|
||||
],
|
||||
)
|
102
tensorflow/tensorboard/components/tf_dashboard_common/BUILD
Normal file
102
tensorflow/tensorboard/components/tf_dashboard_common/BUILD
Normal file
@ -0,0 +1,102 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
webfiles(
|
||||
name = "tf_dashboard_common",
|
||||
srcs = glob(["*.html"]) + [
|
||||
":ts",
|
||||
],
|
||||
path = "/tf-dashboard-common",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_imports:lodash",
|
||||
"//tensorflow/tensorboard/components/tf_imports:plottable",
|
||||
"//tensorflow/tensorboard/components/tf_storage",
|
||||
"//tensorflow/tensorboard/components/vz_sorting",
|
||||
"@org_polymer",
|
||||
"@org_polymer_iron_ajax",
|
||||
"@org_polymer_iron_collapse",
|
||||
"@org_polymer_iron_icons",
|
||||
"@org_polymer_paper_button",
|
||||
"@org_polymer_paper_checkbox",
|
||||
"@org_polymer_paper_dialog",
|
||||
"@org_polymer_paper_dropdown_menu",
|
||||
"@org_polymer_paper_icon_button",
|
||||
"@org_polymer_paper_input",
|
||||
"@org_polymer_paper_item",
|
||||
"@org_polymer_paper_menu",
|
||||
"@org_polymer_paper_slider",
|
||||
"@org_polymer_paper_spinner",
|
||||
"@org_polymer_paper_styles",
|
||||
"@org_polymer_paper_toggle_button",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_typescript_genrule(
|
||||
name = "ts",
|
||||
srcs = [
|
||||
"categorizer.ts",
|
||||
"dashboard-behavior.ts",
|
||||
"reload-behavior.ts",
|
||||
],
|
||||
typings = [
|
||||
"@org_definitelytyped//:d3.d.ts",
|
||||
"//tensorflow/tensorboard/components/vz_sorting:ts_typings",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# MARKED FOR DELETION
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "legacy",
|
||||
srcs = glob(["*.html"]) + [":legacy_ts"],
|
||||
destdir = "tf-dashboard-common",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:tf_imports",
|
||||
"//tensorflow/tensorboard/components/tf_storage:legacy",
|
||||
"//tensorflow/tensorboard/components/vz_sorting:legacy",
|
||||
"//third_party/javascript/polymer/v1/iron-ajax:lib",
|
||||
"//third_party/javascript/polymer/v1/iron-collapse:lib",
|
||||
"//third_party/javascript/polymer/v1/iron-icons:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-button:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-checkbox:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-dialog:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-icon-button:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-input:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-item:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-menu:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-slider:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-spinner:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-toggle-button:lib",
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "legacy_ts",
|
||||
srcs = [
|
||||
"categorizer.ts",
|
||||
"dashboard-behavior.ts",
|
||||
"reload-behavior.ts",
|
||||
],
|
||||
deps_mgmt = "off",
|
||||
runtime = "nodejs",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:common_deps",
|
||||
"//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
|
||||
],
|
||||
)
|
@ -0,0 +1,31 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# bazel run //third_party/tensorflow/tensorboard/components/tf_dashboard_common/demo
|
||||
webfiles(
|
||||
name = "demo",
|
||||
srcs = [
|
||||
"tf-categorizer-demo.html",
|
||||
"tf-collapsable-pane-demo.html",
|
||||
"tf-multi-checkbox-demo.html",
|
||||
"tf-regex-group-demo.html",
|
||||
],
|
||||
path = "/tf-dashboard-common/demo",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_color_scale",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common",
|
||||
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||
"@org_polymer_iron_flex_layout",
|
||||
"@org_polymer_paper_styles",
|
||||
"@org_polymer_webcomponentsjs",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
@ -57,6 +57,8 @@ plugin is requred to implement two functions:
|
||||
</style>
|
||||
</template>
|
||||
<script>
|
||||
"use strict";
|
||||
|
||||
Polymer({
|
||||
is: "tf-chart-scaffold",
|
||||
properties: {
|
||||
|
114
tensorflow/tensorboard/components/tf_dashboard_common_d3v4/BUILD
Normal file
114
tensorflow/tensorboard/components/tf_dashboard_common_d3v4/BUILD
Normal file
@ -0,0 +1,114 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load(
|
||||
"//tensorflow/tensorboard:defs.bzl",
|
||||
"tensorboard_ts_development_sources",
|
||||
"tensorboard_ts_devserver",
|
||||
"tensorboard_ts_library",
|
||||
"tensorboard_webcomponent_library",
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "tf_dashboard_common",
|
||||
srcs = [
|
||||
"dashboard-style.html",
|
||||
"run-color-style.html",
|
||||
"scrollbar-style.html",
|
||||
"tensorboard-color.html",
|
||||
"tf-categorizer.html",
|
||||
"tf-collapsable-pane.html",
|
||||
"tf-dashboard.html",
|
||||
"tf-dashboard-layout.html",
|
||||
"tf-downloader.html",
|
||||
"tf-multi-checkbox.html",
|
||||
"tf-no-data-warning.html",
|
||||
"tf-option-selector.html",
|
||||
"tf-panes-helper.html",
|
||||
"tf-regex-group.html",
|
||||
"tf-run-selector.html",
|
||||
"tf-sidebar-helper.html",
|
||||
],
|
||||
ts_lib_deps = [":ts"],
|
||||
deps = [
|
||||
"//third_party/javascript/plottable/v3:lib",
|
||||
"//third_party/javascript/polymer/v1/iron-ajax:lib",
|
||||
"//third_party/javascript/polymer/v1/iron-collapse:lib",
|
||||
"//third_party/javascript/polymer/v1/iron-icons:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-button:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-checkbox:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-dialog:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-icon-button:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-input:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-item:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-menu:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-slider:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-spinner:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-toggle-button:lib",
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "ts",
|
||||
srcs = [
|
||||
"dashboard-behavior.ts",
|
||||
"reload-behavior.ts",
|
||||
"tf-categorizer.ts",
|
||||
"tf-multi-checkbox.ts",
|
||||
"tf-regex-group.ts",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:common_deps_d3v4",
|
||||
"//tensorflow/tensorboard/components/tf_storage_d3v4:ts",
|
||||
"//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "tests",
|
||||
srcs = ["tf-categorizer-tests.ts"],
|
||||
deps = [
|
||||
":ts",
|
||||
"//tensorflow/tensorboard/components:common_deps_d3v4",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "demo",
|
||||
srcs = [
|
||||
"tf-categorizer-demo.html",
|
||||
"tf-collapsable-pane-demo.html",
|
||||
"tf-multi-checkbox-demo.html",
|
||||
"tf-regex-group-demo.html",
|
||||
],
|
||||
deps = [
|
||||
":tf_dashboard_common",
|
||||
"//tensorflow/tensorboard/components/tf_color_scale_d3v4:tf_color_scale",
|
||||
"//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_devserver(
|
||||
name = "devserver",
|
||||
manifest = ":dev_sources",
|
||||
serving_path = "/demo_out/bundle.js",
|
||||
static_files = [":demo"],
|
||||
)
|
||||
|
||||
tensorboard_ts_development_sources(
|
||||
name = "dev_sources",
|
||||
deps = [
|
||||
":ts",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
@ -55,6 +55,8 @@ plugin is requred to implement two functions:
|
||||
</style>
|
||||
</template>
|
||||
<script>
|
||||
"use strict";
|
||||
|
||||
Polymer({
|
||||
is: "tf-chart-scaffold",
|
||||
properties: {
|
||||
|
@ -0,0 +1,63 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
webfiles(
|
||||
name = "tf_distribution_dashboard",
|
||||
srcs = [
|
||||
"tf-distribution-dashboard.html",
|
||||
],
|
||||
path = "/tf-distribution-dashboard",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_backend",
|
||||
"//tensorflow/tensorboard/components/tf_color_scale",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common",
|
||||
"//tensorflow/tensorboard/components/tf_imports:lodash",
|
||||
"//tensorflow/tensorboard/components/vz_distribution_chart",
|
||||
"@org_polymer",
|
||||
"@org_polymer_iron_collapse",
|
||||
"@org_polymer_paper_icon_button",
|
||||
"@org_polymer_paper_styles",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# MARKED FOR DELETION
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "legacy",
|
||||
srcs = [
|
||||
"tf-distribution-dashboard.html",
|
||||
":legacy_ts",
|
||||
],
|
||||
destdir = "tf-distribution-dashboard",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:tf_imports",
|
||||
"//tensorflow/tensorboard/components/tf_backend:legacy",
|
||||
"//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
|
||||
"//tensorflow/tensorboard/components/vz_distribution_chart:legacy",
|
||||
"//third_party/javascript/polymer/v1/iron-collapse:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-icon-button:lib",
|
||||
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
# This is needed: components/BUILD seeks a legacy_ts rule in this package.
|
||||
tensorboard_ts_library(
|
||||
name = "legacy_ts",
|
||||
srcs = [],
|
||||
deps_mgmt = "off",
|
||||
runtime = "nodejs",
|
||||
deps = ["//tensorflow/tensorboard/components:common_deps"],
|
||||
)
|
@ -0,0 +1,26 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# bazel run //third_party/tensorflow/tensorboard/components/tf_distribution_dashboard/demo
|
||||
webfiles(
|
||||
name = "demo",
|
||||
srcs = ["index.html"],
|
||||
path = "/tf-distribution-dashboard/demo",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_distribution_dashboard",
|
||||
"//tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data",
|
||||
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||
"@org_polymer_iron_demo_helpers",
|
||||
"@org_polymer_paper_styles",
|
||||
"@org_polymer_webcomponentsjs",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
@ -0,0 +1,17 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
webfiles(
|
||||
name = "data",
|
||||
srcs = glob(["*"]),
|
||||
path = "/tf-distribution-dashboard/demo/data",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
49
tensorflow/tensorboard/components/tf_globals/BUILD
Normal file
49
tensorflow/tensorboard/components/tf_globals/BUILD
Normal file
@ -0,0 +1,49 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
# TODO(dandelion): Add webfiles support for the test code.
|
||||
|
||||
webfiles(
|
||||
name = "tf_globals",
|
||||
srcs = [
|
||||
"tf-globals.html",
|
||||
":ts",
|
||||
],
|
||||
path = "/tf-globals",
|
||||
)
|
||||
|
||||
tensorboard_typescript_genrule(
|
||||
name = "ts",
|
||||
srcs = ["globals.ts"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# MARKED FOR DELETION
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "legacy",
|
||||
srcs = [
|
||||
"tf-globals.html",
|
||||
":legacy_ts",
|
||||
],
|
||||
destdir = "tf-globals",
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "legacy_ts",
|
||||
srcs = ["globals.ts"],
|
||||
deps_mgmt = "off",
|
||||
runtime = "nodejs",
|
||||
)
|
16
tensorflow/tensorboard/components/tf_globals_d3v4/BUILD
Normal file
16
tensorflow/tensorboard/components/tf_globals_d3v4/BUILD
Normal file
@ -0,0 +1,16 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "ts",
|
||||
srcs = ["globals.ts"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
65
tensorflow/tensorboard/components/tf_graph_common/BUILD
Normal file
65
tensorflow/tensorboard/components/tf_graph_common/BUILD
Normal file
@ -0,0 +1,65 @@
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
|
||||
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
webfiles(
|
||||
name = "tf_graph_common",
|
||||
srcs = [
|
||||
"tf-graph-common.html",
|
||||
":ts",
|
||||
],
|
||||
path = "/tf-graph-common",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||
"//tensorflow/tensorboard/components/tf_imports:dagre",
|
||||
"//tensorflow/tensorboard/components/tf_imports:graphlib",
|
||||
"//tensorflow/tensorboard/components/tf_imports:lodash",
|
||||
"@org_polymer",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_typescript_genrule(
|
||||
name = "ts",
|
||||
srcs = glob(["*.ts"]),
|
||||
typings = [
|
||||
"@org_definitelytyped//:d3.d.ts",
|
||||
"@org_definitelytyped//:lodash.d.ts",
|
||||
"@org_definitelytyped//:polymer.d.ts",
|
||||
"@org_definitelytyped//:webcomponents.js.d.ts",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(["**"]),
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# MARKED FOR DELETION
|
||||
|
||||
tensorboard_webcomponent_library(
|
||||
name = "legacy",
|
||||
srcs = [
|
||||
"tf-graph-common.html",
|
||||
":legacy_ts",
|
||||
],
|
||||
destdir = "tf-graph-common",
|
||||
deps = [
|
||||
"//tensorflow/tensorboard/components:tf_imports",
|
||||
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tensorboard_ts_library(
|
||||
name = "legacy_ts",
|
||||
srcs = glob(["*.ts"]),
|
||||
deps_mgmt = "off",
|
||||
runtime = "nodejs",
|
||||
deps = ["//tensorflow/tensorboard/components:common_deps"],
|
||||
)
|
@ -103,6 +103,8 @@ out-hierarchy-params="{{_hierarchyParams}}"
|
||||
</dom-module>
|
||||
|
||||
<script>
|
||||
"use strict";
|
||||
|
||||
(function() {
|
||||
TF.Dashboard.TfGraphDashboard = Polymer({
|
||||
is: 'tf-graph-dashboard',
|
||||
|
@ -169,6 +169,8 @@ h2 {
|
||||
</template>
|
||||
</template>
|
||||
<script>
|
||||
"use strict";
|
||||
|
||||
(function() {
|
||||
Polymer({
|
||||
is: 'tf-graph-info',
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user