site/en/guide/migrate/logging_stop_hook.ipynb
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In TensorFlow 1, you use tf.estimator.LoggingTensorHook to monitor and log tensors, while tf.estimator.StopAtStepHook helps stop training at a specified step when training with tf.estimator.Estimator. This notebook demonstrates how to migrate from these APIs to their equivalents in TensorFlow 2 using custom Keras callbacks (tf.keras.callbacks.Callback) with Model.fit.
Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides. For migrating from SessionRunHook in TensorFlow 1 to Keras callbacks in TensorFlow 2, check out the Migrate training with assisted logic guide.
Start with imports and a simple dataset for demonstration purposes:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
# Define an input function.
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
In TensorFlow 1, you define various hooks to control the training behavior. Then, you pass these hooks to tf.estimator.EstimatorSpec.
In the example below:
- To monitor and log tensors (for example, model weights or losses), you use tf.estimator.LoggingTensorHook (tf.train.LoggingTensorHook is its alias).
- To stop training at a specific step, you use tf.estimator.StopAtStepHook (tf.train.StopAtStepHook is its alias).
def _model_fn(features, labels, mode):
  dense = tf1.layers.Dense(1)
  logits = dense(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  # Define the stop hook.
  stop_hook = tf1.train.StopAtStepHook(num_steps=2)
  # Access tensors to be logged by names.
  kernel_name = tf.identity(dense.weights[0])
  bias_name = tf.identity(dense.weights[1])
  logging_weight_hook = tf1.train.LoggingTensorHook(
      tensors=[kernel_name, bias_name],
      every_n_iter=1)
  # Log the training loss by the tensor object.
  logging_loss_hook = tf1.train.LoggingTensorHook(
      {'loss from LoggingTensorHook': loss},
      every_n_secs=3)
  # Pass all hooks to `EstimatorSpec`.
  return tf1.estimator.EstimatorSpec(mode,
                                     loss=loss,
                                     train_op=train_op,
                                     training_hooks=[stop_hook,
                                                     logging_weight_hook,
                                                     logging_loss_hook])
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tensor monitoring and training stopping by defining custom Keras tf.keras.callbacks.Callbacks. Then, you pass them to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)
In the example below:
- To recreate the functionality of StopAtStepHook, define a custom callback (named StopAtStepCallback below) where you override the on_batch_end method to stop training after a certain number of steps.
- To recreate the LoggingTensorHook behavior, define a custom callback (LoggingTensorCallback) where you record and output the logged tensors manually, since accessing tensors by name is not supported. You can also implement the logging frequency inside the custom callback. The example below prints the weights every two steps. Other strategies, such as logging every N seconds, are also possible.
class StopAtStepCallback(tf.keras.callbacks.Callback):
  def __init__(self, stop_step):
    super().__init__()
    self._stop_step = stop_step
  def on_batch_end(self, batch, logs=None):
    # `optimizer.iterations` counts the training steps completed so far.
    if self.model.optimizer.iterations >= self._stop_step:
      self.model.stop_training = True
      print('\nstop training now')
class LoggingTensorCallback(tf.keras.callbacks.Callback):
  def __init__(self, every_n_iter):
    super().__init__()
    self._every_n_iter = every_n_iter
    self._log_count = every_n_iter
  def on_batch_end(self, batch, logs=None):
    # Count down one step per batch and log when the counter reaches zero.
    self._log_count -= 1
    if self._log_count <= 0:
      # Use `self.model` (set by Keras) rather than a global `model`.
      print("Logging Tensor Callback: dense/kernel:",
            self.model.layers[0].weights[0])
      print("Logging Tensor Callback: dense/bias:",
            self.model.layers[0].weights[1])
      print("Logging Tensor Callback loss:", logs["loss"])
      # Restart the countdown for the next logging interval.
      self._log_count = self._every_n_iter
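The "logging every N seconds" strategy mentioned above could be sketched as follows. This is a minimal, framework-independent illustration of the throttling logic only, not part of the TensorFlow API: the class name EveryNSecsTrigger, its should_fire method, and the injectable clock parameter are all hypothetical names chosen for this example. Firing on the very first call is a design choice of the sketch.

```python
import time


class EveryNSecsTrigger:
  """Hypothetical helper that fires at most once every `every_n_secs` seconds."""

  def __init__(self, every_n_secs, clock=time.monotonic):
    self._every_n_secs = every_n_secs
    # The clock is injectable so the logic can be exercised without sleeping.
    self._clock = clock
    self._last_fired = None

  def should_fire(self):
    now = self._clock()
    # Fire on the first call, then only once the interval has elapsed.
    if (self._last_fired is None or
        now - self._last_fired >= self._every_n_secs):
      self._last_fired = now
      return True
    return False


# Simulate five batches that finish at these (fake) timestamps, in seconds.
fake_times = iter([0.0, 1.0, 3.5, 4.0, 7.0])
trigger = EveryNSecsTrigger(every_n_secs=3, clock=lambda: next(fake_times))
fired = [trigger.should_fire() for _ in range(5)]
print(fired)
```

In a custom callback, one way to use this would be to create a trigger in `__init__` and guard the `print` calls in `on_batch_end` with `if self._trigger.should_fire():` (where `self._trigger` is again an assumed attribute name).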
When finished, pass the new callbacks—StopAtStepCallback and LoggingTensorCallback—to the callbacks parameter of Model.fit:
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
                              LoggingTensorCallback(every_n_iter=2)])
Learn more about callbacks in:
- The tf.keras.callbacks.Callback API docs
You may also find the following migration-related resources useful:
- tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
- The Migrate training with assisted logic guide: from SessionRunHook in TensorFlow 1 to Keras callbacks in TensorFlow 2