Access to multi-level features (#201)
vturrisi authored Dec 1, 2021
1 parent 0ac7b8f commit f76ec9e
Showing 3 changed files with 605 additions and 14 deletions.
18 changes: 13 additions & 5 deletions README.md
@@ -15,8 +15,10 @@ While the library is self-contained, it is possible to use the models outside of
---

## News
* **[Dec 01 2021]**: :fountain: Added support for getting mid-level features and [PoolFormer](https://arxiv.org/abs/2111.11418).
* **[Nov 29 2021]**: :bangbang: Breaking changes! Update your versions!!!
* **[Nov 29 2021]**: :book: New tutorials!
* **[Nov 29 2021]**: :houses: Added offline K-NN and offline UMAP.
* **[Nov 29 2021]**: :rotating_light: Updated PyTorch and PyTorch Lightning versions. 10% faster.
* **[Nov 29 2021]**: :beers: Added code of conduct, contribution instructions, issue templates and UMAP tutorial.
* **[Nov 23 2021]**: :space_invader: Added [VIbCReg](https://arxiv.org/abs/2109.00783).
@@ -27,9 +29,6 @@ While the library is self-contained, it is possible to use the models outside of
* **[Sep 17 2021]**: :robot: Added [ViT](https://arxiv.org/abs/2010.11929) and [Swin](https://arxiv.org/abs/2103.14030).
* **[Sep 13 2021]**: :book: Improved [Docs](https://solo-learn.readthedocs.io/en/latest/?badge=latest) and added tutorials for [pretraining](https://solo-learn.readthedocs.io/en/latest/tutorials/overview.html) and [offline linear eval](https://solo-learn.readthedocs.io/en/latest/tutorials/offline_linear_eval.html).
* **[Aug 13 2021]**: :whale: [DeepCluster V2](https://arxiv.org/abs/2006.09882) is now available.
* **[Jul 31 2021]**: :hedgehog: [ReSSL](https://arxiv.org/abs/2107.09282) is now available.
* **[Jul 21 2021]**: :test_tube: Added Custom Dataset support.
* **[Jul 21 2021]**: :carousel_horse: Added AutoUMAP.

---

@@ -55,18 +54,27 @@ While the library is self-contained, it is possible to use the models outside of

## Extra flavor

### Multiple backbones
* [ResNet](https://arxiv.org/abs/1512.03385)
* [ViT](https://arxiv.org/abs/2010.11929)
* [Swin](https://arxiv.org/abs/2103.14030)
* [PoolFormer](https://arxiv.org/abs/2111.11418)

### Data
* Increased data processing speed by up to 100% using [NVIDIA DALI](https://github.com/NVIDIA/DALI).
* Asymmetric and symmetric augmentations.
* Flexible augmentations.

### Evaluation and logging
* Online linear evaluation via stop-gradient for easier debugging and prototyping (optionally available for the momentum backbone as well).
* Online Knn evaluation.
* Online and offline K-NN evaluation.
* Normal offline linear evaluation.
* All the perks of PyTorch Lightning (mixed precision, gradient accumulation, clipping, automatic logging and much more).
* Easy-to-extend modular code structure.
* Custom model logging with a simpler file organization.
* Automatic feature space visualization with UMAP.
* Offline UMAP.
* Common metrics and more to come...

### Training tricks
* Multi-cropping dataloading following [SwAV](https://arxiv.org/abs/2006.09882):
* **Note**: currently, only SimCLR supports this.
69 changes: 61 additions & 8 deletions solo/methods/base.py
@@ -26,13 +26,12 @@
import torch.nn as nn
import torch.nn.functional as F
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from solo.utils.knn import WeightedKNNClassifier
from solo.utils.lars import LARSWrapper
from solo.utils.metrics import accuracy_at_k, weighted_mean
from solo.utils.momentum import MomentumUpdater, initialize_momentum_params
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
from torchvision.models import resnet18, resnet50
from solo.utils.backbones import (
poolformer_m36,
poolformer_m48,
poolformer_s12,
poolformer_s24,
poolformer_s36,
swin_base,
swin_large,
swin_small,
@@ -42,6 +41,13 @@
vit_small,
vit_tiny,
)
from solo.utils.knn import WeightedKNNClassifier
from solo.utils.lars import LARSWrapper
from solo.utils.metrics import accuracy_at_k, weighted_mean
from solo.utils.momentum import MomentumUpdater, initialize_momentum_params
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
from torchvision.models import resnet18, resnet50
from torchvision.models.feature_extraction import create_feature_extractor


def static_lr(
@@ -66,6 +72,28 @@ class BaseMethod(pl.LightningModule):
"swin_small": swin_small,
"swin_base": swin_base,
"swin_large": swin_large,
"poolformer_s12": poolformer_s12,
"poolformer_s24": poolformer_s24,
"poolformer_s36": poolformer_s36,
"poolformer_m36": poolformer_m36,
"poolformer_m48": poolformer_m48,
}

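# torchvision feature-extraction node names for each supported backbone: the
# first four entries tap the output of the last ReLU in each residual stage
# (layer1-layer4) and "flatten" taps the pooled final representation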
_NODE_NAMES = {
"resnet18": [
"layer1.1.relu_1",
"layer2.1.relu_1",
"layer3.1.relu_1",
"layer4.1.relu_1",
"flatten",
],
"resnet50": [
"layer1.2.relu_2",
"layer2.3.relu_2",
"layer3.5.relu_2",
"layer4.2.relu_2",
"flatten",
],
}

def __init__(
@@ -221,6 +249,14 @@ def __init__(
else:
self.features_dim = self.backbone.num_features

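# if feature-extraction node names are known for this backbone, wrap it with
# torchvision's create_feature_extractor so that forward() returns a dict
# mapping each node name to its intermediate output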
self.node_names = BaseMethod._NODE_NAMES.get(backbone, None)
self.supports_multilevel = self.node_names is not None
if self.supports_multilevel:
self.backbone = create_feature_extractor(
self.backbone,
return_nodes=self.node_names,
)

self.classifier = nn.Linear(self.features_dim, num_classes)

if self.knn_eval:
@@ -402,8 +438,23 @@ def base_forward(self, X: torch.Tensor) -> Dict:
"""

feats = self.backbone(X)
multilevel_feats = {}
if self.supports_multilevel:
# features for multiple levels of the backbone
multilevel_feats = feats

# parse the features and split them into mid-level features and the final representation
multilevel_feats = [multilevel_feats[n] for n in self.node_names]

feats = multilevel_feats.pop(-1)
multilevel_feats = {f"feats-lvl{i}": f for i, f in enumerate(multilevel_feats)}

logits = self.classifier(feats.detach())
return {"logits": logits, "feats": feats}
return {
"logits": logits,
"feats": feats,
**multilevel_feats,
}

def _base_shared_step(self, X: torch.Tensor, targets: torch.Tensor) -> Dict:
"""Forwards a batch of images X and computes the classification loss, the logits, the
@@ -451,7 +502,9 @@ def training_step(self, batch: List[Any], batch_idx: int) -> Dict[str, Any]:
outs = {k: [out[k] for out in outs] for k in outs[0].keys()}

if self.multicrop:
outs["feats"].extend([self.backbone(x) for x in X[self.num_large_crops :]])
outs["feats"].extend(
[self.backbone(x)[self.node_names[-1]] for x in X[self.num_large_crops :]]
)

# loss and stats
outs["loss"] = sum(outs["loss"]) / self.num_large_crops
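Not part of the commit: a minimal standalone sketch of the mechanism the new code relies on, using the same `resnet18` node names as `_NODE_NAMES` and mirroring the feature parsing in `base_forward`. The variable names here are illustrative only.

```python
import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

# same node names as BaseMethod._NODE_NAMES["resnet18"]
node_names = [
    "layer1.1.relu_1",
    "layer2.1.relu_1",
    "layer3.1.relu_1",
    "layer4.1.relu_1",
    "flatten",
]
backbone = create_feature_extractor(resnet18(), return_nodes=node_names)

X = torch.randn(4, 3, 224, 224)
feats = backbone(X)  # dict: node name -> intermediate output

# mirror base_forward: the last node is the final representation,
# the earlier ones become feats-lvl0 ... feats-lvl3
multilevel = [feats[n] for n in node_names]
final = multilevel.pop(-1)  # shape [4, 512] for resnet18
levels = {f"feats-lvl{i}": f for i, f in enumerate(multilevel)}
print(final.shape, {k: tuple(v.shape) for k, v in levels.items()})
```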