.. currentmodule:: torchvision.models.feature_extraction
The ``torchvision.models.feature_extraction`` package contains
feature extraction utilities that let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.
|
The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
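One of these steps, keeping only the requested output nodes and everything
upstream of them while discarding the rest, can be illustrated with a toy
graph. This is a conceptual sketch in plain Python, not the actual
``torch.fx`` implementation:

.. code-block:: python

    def prune_graph(deps, outputs):
        """Keep `outputs` and all of their (transitive) dependencies.

        `deps` maps each node name to the list of node names it depends on,
        in execution order.
        """
        keep = set()
        stack = list(outputs)
        while stack:
            node = stack.pop()
            if node in keep:
                continue
            keep.add(node)
            stack.extend(deps.get(node, []))
        # Preserve the original execution order
        return [n for n in deps if n in keep]

    # A linear toy model: x -> conv1 -> layer1 -> layer2 -> fc
    deps = {
        "x": [],
        "conv1": ["x"],
        "layer1": ["conv1"],
        "layer2": ["layer1"],
        "fc": ["layer2"],
    }

    # Requesting "layer1" drops everything downstream of it
    print(prune_graph(deps, ["layer1"]))  # ['x', 'conv1', 'layer1']
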
.. _about-node-names:
About Node Names
----------------
In order to specify which nodes should be output nodes for extracted
features, one should be familiar with the node naming convention used here
(which differs slightly from that used in ``torch.fx``). A node name is
specified as a ``.`` separated path walking the module hierarchy from top level
module down to leaf operation or leaf module. For instance ``"layer4.2.relu"``
in ResNet-50 represents the output of the ReLU of the 2nd block of the 4th
layer of the ``ResNet`` module. Here are some finer points to keep in mind:
- When specifying node names for :func:`create_feature_extractor`, you may
  provide a truncated version of a node name as a shortcut. To see how this
  works, try creating a ResNet-50 model and printing the node names with
  ``train_nodes, _ = get_graph_node_names(model)`` followed by
  ``print(train_nodes)``, and observe that the last node pertaining to
  ``layer4`` is ``"layer4.2.relu_2"``. One may specify ``"layer4.2.relu_2"``
  as the return node, or just ``"layer4"`` as this, by convention, refers to
  the last node (in order of execution) of ``layer4``.
- If a certain module or operation is repeated more than once, node names get
  an additional ``_{int}`` postfix to disambiguate. For instance, maybe the
  addition (``+``) operation is used three times in the same ``forward``
  method. Then there would be ``"path.to.module.add"``,
  ``"path.to.module.add_1"``, ``"path.to.module.add_2"``. The counter is
  maintained within the scope of the direct parent. So in ResNet-50 there is
  a ``"layer4.1.add"`` and a ``"layer4.2.add"``. Because the addition
  operations reside in different blocks, there is no need for a postfix to
  disambiguate.
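The postfix convention above can be sketched with a small, purely
illustrative helper. ``assign_node_names`` is a hypothetical function, not
part of torchvision; it only mimics how the counter is scoped to the direct
parent module:

.. code-block:: python

    from collections import defaultdict

    def assign_node_names(ops):
        """ops: list of (parent_path, op_name) pairs in execution order."""
        counters = defaultdict(int)   # counter keyed by (parent, op)
        names = []
        for parent, op in ops:
            n = counters[(parent, op)]
            counters[(parent, op)] += 1
            # First occurrence gets no postfix, repeats get _1, _2, ...
            suffix = "" if n == 0 else f"_{n}"
            names.append(f"{parent}.{op}{suffix}")
        return names

    # Three adds inside one module get postfixes; adds in *different*
    # parent modules each start their own counter.
    print(assign_node_names([
        ("path.to.module", "add"),
        ("path.to.module", "add"),
        ("path.to.module", "add"),
        ("layer4.1", "add"),
        ("layer4.2", "add"),
    ]))
    # ['path.to.module.add', 'path.to.module.add_1', 'path.to.module.add_2',
    #  'layer4.1.add', 'layer4.2.add']
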
An Example
----------

Here is an example of how we might extract features for MaskRCNN:
.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names
    from torchvision.models.feature_extraction import create_feature_extractor
    from torchvision.models.detection.mask_rcnn import MaskRCNN
    from torchvision.models.detection.backbone_utils import LastLevelMaxPool
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


    # To assist you in designing the feature extractor you may want to print
    # out the available nodes for resnet50.
    m = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(resnet50())

    # The lists returned are the names of all the graph nodes (in order of
    # execution) for the input model traced in train mode and in eval mode
    # respectively. You'll find that `train_nodes` and `eval_nodes` are the
    # same for this example. But if the model contains control flow that's
    # dependent on the training mode, they may be different.

    # To specify the nodes you want to extract, you could select the final
    # node that appears in each of the main layers:
    return_nodes = {
        # node_name: user-specified key for output dict
        'layer1.2.relu_2': 'layer1',
        'layer2.3.relu_2': 'layer2',
        'layer3.5.relu_2': 'layer3',
        'layer4.2.relu_2': 'layer4',
    }

    # But `create_feature_extractor` can also accept truncated node
    # specifications like "layer1", as it will just pick the last node that's
    # a descendent of the specification:
    return_nodes = {
        'layer1': 'layer1',
        'layer2': 'layer2',
        'layer3': 'layer3',
        'layer4': 'layer4',
    }

    # Now you can build the feature extractor. This returns a module whose
    # forward method returns a dictionary like:
    # {'layer1': output of layer 1, 'layer2': output of layer 2, ...}
    create_feature_extractor(m, return_nodes=return_nodes)

    # Let's put all that together to wrap resnet50 with MaskRCNN

    # MaskRCNN requires a backbone with an attached FPN
    class Resnet50WithFPN(torch.nn.Module):
        def __init__(self):
            super(Resnet50WithFPN, self).__init__()
            # Get a resnet50 backbone
            m = resnet50()
            # Extract 4 main layers (note: MaskRCNN needs this particular
            # name mapping for return nodes)
            self.body = create_feature_extractor(
                m, return_nodes={f'layer{k}': str(v)
                                 for v, k in enumerate([1, 2, 3, 4])})
            # Dry run to get number of channels for FPN
            inp = torch.randn(2, 3, 224, 224)
            with torch.no_grad():
                out = self.body(inp)
            in_channels_list = [o.shape[1] for o in out.values()]
            # Build FPN
            self.out_channels = 256
            self.fpn = FeaturePyramidNetwork(
                in_channels_list, out_channels=self.out_channels,
                extra_blocks=LastLevelMaxPool())

        def forward(self, x):
            x = self.body(x)
            x = self.fpn(x)
            return x


    # Now we can build our model!
    model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
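For a smaller, self-contained sanity check of :func:`create_feature_extractor`,
it can also be applied to a tiny custom module. ``Tiny`` below is a made-up
example model, not part of torchvision; note that the node name ``'relu'``
follows the naming convention described above:

.. code-block:: python

    import torch
    from torchvision.models.feature_extraction import create_feature_extractor

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
            self.relu = torch.nn.ReLU()
            self.pool = torch.nn.AdaptiveAvgPool2d(1)

        def forward(self, x):
            return self.pool(self.relu(self.conv(x)))

    # Tap the output of the ReLU; everything downstream ('pool') is pruned
    extractor = create_feature_extractor(Tiny(), return_nodes={'relu': 'feat'})
    out = extractor(torch.randn(1, 3, 16, 16))
    print(out['feat'].shape)  # torch.Size([1, 8, 16, 16])
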
API Reference
-------------

.. autosummary::
    :toctree: generated/
    :template: function.rst

    create_feature_extractor
    get_graph_node_names