unify num_classes to class_num
TingquanGao committed Jun 27, 2023
1 parent 068b30f commit 05fecac
Showing 12 changed files with 40 additions and 40 deletions.
20 changes: 10 additions & 10 deletions passl/models/cae.py
@@ -697,7 +697,7 @@ def forward(self, x, bool_masked_pos, return_all_tokens=False):
class CAERegressorDecoder(nn.Layer):
def __init__(self,
patch_size=16,
- num_classes=8192,
+ class_num=8192,
embed_dim=768,
depth=6,
num_heads=12,
@@ -760,7 +760,7 @@ def __init__(self,
if args.num_decoder_self_attention > 0:
self.norm2 = norm_layer(embed_dim)
self.head = nn.Linear(
- embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ embed_dim, class_num) if class_num > 0 else nn.Identity()

self.init_std = init_std

@@ -907,7 +907,7 @@ def __init__(self,

self.regressor_and_decoder = CAERegressorDecoder(
patch_size=patch_size,
- num_classes=args.decoder_num_classes,
+ class_num=args.decoder_class_num,
embed_dim=args.decoder_embed_dim,
depth=args.regressor_depth,
num_heads=args.decoder_num_heads,
@@ -1083,7 +1083,7 @@ def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
- num_classes=1000,
+ class_num=1000,
embed_dim=768,
depth=12,
num_heads=12,
@@ -1103,7 +1103,7 @@ def __init__(self,
lin_probe=False,
args=None):
super().__init__()
- self.num_classes = num_classes
+ self.class_num = class_num
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.use_mean_pooling = use_mean_pooling

@@ -1193,8 +1193,8 @@ def __init__(self,
init.trunc_normal_(self.query_token, std=.02)

self.head = nn.Linear(
- embed_dim, num_classes,
- bias_attr=True) if num_classes > 0 else nn.Identity()
+ embed_dim, class_num,
+ bias_attr=True) if class_num > 0 else nn.Identity()

if self.pos_embed is not None and use_abs_pos_emb:
init.trunc_normal_(self.pos_embed, std=.02)
@@ -1266,10 +1266,10 @@ def no_weight_decay(self):
def get_classifier(self):
return self.head

- def reset_classifier(self, num_classes, global_pool=''):
- self.num_classes = num_classes
+ def reset_classifier(self, class_num, global_pool=''):
+ self.class_num = class_num
self.head = nn.Linear(
- self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.embed_dim, class_num) if class_num > 0 else nn.Identity()

def forward_features(self, x, is_train=True):
x = self.patch_embed(x)
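Every hunk in this commit touches the same idiom: the classifier head is materialized only when the class count is positive. A minimal standalone sketch of the renamed pattern, assuming nothing beyond Paddle itself (none of the CAE classes above):

    import paddle
    import paddle.nn as nn

    embed_dim = 768   # backbone feature width
    class_num = 1000  # renamed from num_classes throughout this commit

    # Linear projection to logits when class_num > 0, otherwise a pass-through.
    head = nn.Linear(embed_dim, class_num) if class_num > 0 else nn.Identity()

    x = paddle.randn([2, embed_dim])
    logits = head(x)  # shape [2, 1000]; with class_num == 0 it would stay [2, 768]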
12 changes: 6 additions & 6 deletions passl/models/cait.py
@@ -234,7 +234,7 @@ def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
- num_classes=1000,
+ class_num=1000,
global_pool='token',
embed_dim=768,
depth=12,
@@ -260,7 +260,7 @@ def __init__(self,
super().__init__()
assert global_pool in ('', 'token', 'avg')

- self.num_classes = num_classes
+ self.class_num = class_num
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim

@@ -319,7 +319,7 @@ def __init__(self,
num_chs=embed_dim, reduction=0, module='head')
]
self.head = nn.Linear(
- embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ embed_dim, class_num) if class_num > 0 else nn.Identity()

init.trunc_normal_(self.pos_embed, std=.02)
init.trunc_normal_(self.cls_token, std=.02)
@@ -340,14 +340,14 @@ def no_weight_decay(self):
def get_classifier(self):
return self.head

- def reset_classifier(self, num_classes, global_pool=None):
- self.num_classes = num_classes
+ def reset_classifier(self, class_num, global_pool=None):
+ self.class_num = class_num
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(
self.num_features,
- num_classes) if num_classes > 0 else nn.Identity()
+ class_num) if class_num > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)
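The rename also changes the public reset_classifier signature, so downstream fine-tuning code must update the keyword it passes. A hypothetical call, assuming model is an instance of the class defined in this file:

    # Swap the pretrained head for a fresh 10-way classifier.
    model.reset_classifier(class_num=10)

    # class_num=0 drops the head: get_classifier() then returns nn.Identity().
    model.reset_classifier(class_num=0)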
12 changes: 6 additions & 6 deletions passl/models/convmae/conv_vit.py
@@ -180,7 +180,7 @@ def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
- num_classes=1000,
+ class_num=1000,
embed_dim=768,
depth=12,
num_heads=12,
@@ -195,7 +195,7 @@ def __init__(self,
global_pool=False,
**kwargs):
super().__init__()
- self.num_classes = num_classes
+ self.class_num = class_num
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models

if hybrid_backbone is not None:
@@ -269,7 +269,7 @@ def __init__(self,

# Classifier head
self.head = nn.Linear(
- embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
+ embed_dim[-1], class_num) if class_num > 0 else nn.Identity()

init.trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
@@ -294,10 +294,10 @@ def no_weight_decay(self):
def get_classifier(self):
return self.head

- def reset_classifier(self, num_classes, global_pool=''):
- self.num_classes = num_classes
+ def reset_classifier(self, class_num, global_pool=''):
+ self.class_num = class_num
self.head = nn.Linear(
- self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.embed_dim, class_num) if class_num > 0 else nn.Identity()

def forward_features(self, x):
B = x.shape[0]
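One detail specific to ConvViT: embed_dim is per-stage here, so the head built in __init__ projects from the final stage width, embed_dim[-1]; the rename does not change that.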
2 changes: 1 addition & 1 deletion passl/models/convnext.py
@@ -108,7 +108,7 @@ class ConvNeXt(Model):
A Paddle impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
- num_classes (int): Number of classes for classification head. Default: 1000
+ class_num (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
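Only the docstring needed updating in this file. A hedged instantiation sketch using just the documented arguments (the import path is assumed from the file location; all other parameters keep their defaults):

    from passl.models.convnext import ConvNeXt

    model = ConvNeXt(in_chans=3, class_num=10)  # 10-way head per the docstring above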
4 changes: 2 additions & 2 deletions passl/models/dino/dino_vit.py
@@ -114,7 +114,7 @@ def forward(self, x, return_attention=False):
class DINOVisionTransformer(nn.Layer):
""" DINO Vision Transformer """

- def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, class_num=0, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, n_last_blocks=1, avgpool_patchtokens=False, **kwargs):
super().__init__()
@@ -147,7 +147,7 @@ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, emb
self.norm = norm_layer(embed_dim)

# Classifier head
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.head = nn.Linear(embed_dim, class_num) if class_num > 0 else nn.Identity()

self.n_last_blocks = n_last_blocks
self.avgpool_patchtokens = avgpool_patchtokens
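DINO is trained self-supervised, which is why class_num defaults to 0 here: the head collapses to nn.Identity() and the backbone emits embeddings rather than logits. A sketch, with the import path assumed from the file location:

    from passl.models.dino.dino_vit import DINOVisionTransformer

    backbone = DINOVisionTransformer()  # class_num=0 by default -> head is nn.Identity()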
2 changes: 1 addition & 1 deletion passl/models/simsiam.py
@@ -45,7 +45,7 @@ def __init__(self, base_encoder, dim=2048, pred_dim=512):
super(SimSiamPretain, self).__init__()

# create the encoder
- # num_classes is the output fc dimension, zero-initialize last BNs
+ # class_num is the output fc dimension, zero-initialize last BNs
self.encoder = base_encoder(class_num=dim, zero_init_residual=True)

# build a 3-layer projector
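Only the comment changes in this file; the call site beneath it already passed class_num=dim, so the encoder's fc head doubles as the 2048-dimensional projector output.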
14 changes: 7 additions & 7 deletions passl/models/swin_transformer.py
@@ -467,7 +467,7 @@ class SwinTransformer(Model):
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
- num_classes (int): Number of classes for classification head. Default: 1000
+ class_num (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
@@ -487,7 +487,7 @@ def __init__(self,
img_size=224,
patch_size=4,
in_chans=3,
- num_classes=1000,
+ class_num=1000,
global_pool='avg',
embed_dim=96,
depths=(2, 2, 6, 2),
@@ -506,7 +506,7 @@ def __init__(self,
**kwargs):
super().__init__()
assert global_pool in ('', 'avg')
- self.num_classes = num_classes
+ self.class_num = class_num
self.global_pool = global_pool
self.num_layers = len(depths)
self.embed_dim = embed_dim
@@ -560,7 +560,7 @@ def __init__(self,
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(
self.num_features,
- num_classes) if num_classes > 0 else nn.Identity()
+ class_num) if class_num > 0 else nn.Identity()

if self.absolute_pos_embed is not None:
init.trunc_normal_(self.absolute_pos_embed, std=.02)
@@ -595,14 +595,14 @@ def group_matcher(self, coarse=False):
def get_classifier(self):
return self.head

- def reset_classifier(self, num_classes, global_pool=None):
- self.num_classes = num_classes
+ def reset_classifier(self, class_num, global_pool=None):
+ self.class_num = class_num
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(
self.num_features,
- num_classes) if num_classes > 0 else nn.Identity()
+ class_num) if class_num > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)
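Swin's reset_classifier also takes a pooling mode, constrained by the assertion above. A hypothetical fine-tuning reset, assuming model is a constructed SwinTransformer:

    # New 100-way head; global_pool must be '' or 'avg' for this model.
    model.reset_classifier(class_num=100, global_pool='avg')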
2 changes: 1 addition & 1 deletion tasks/ssl/cae/main_finetune.py
@@ -558,7 +558,7 @@ def mixup_collate_fn(batch):
data_loader_val = None

model = models_cae.__dict__[args.model](
- num_classes=args.nb_classes,
+ class_num=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
attn_drop_rate=args.attn_drop_rate,
2 changes: 1 addition & 1 deletion tasks/ssl/cae/main_linprobe.py
@@ -345,7 +345,7 @@ def main(args):
use_shared_memory=args.pin_mem, )

model = models_cae.__dict__[args.model](
- num_classes=args.nb_classes,
+ class_num=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
attn_drop_rate=args.attn_drop_rate,
2 changes: 1 addition & 1 deletion tasks/ssl/cae/main_pretrain.py
@@ -268,7 +268,7 @@ def get_args():
type=int,
help='Number of heads for decoder')
parser.add_argument(
- '--decoder_num_classes',
+ '--decoder_class_num',
default=8192,
type=int,
help='Number of classes for decoder')
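Invocations of the pre-training script must switch from --decoder_num_classes to --decoder_class_num accordingly; the default of 8192 is unchanged.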
4 changes: 2 additions & 2 deletions tasks/ssl/mae/main_finetune.py
@@ -412,13 +412,13 @@ def mixup_collate_fn(batch):

if 'convvit' in args.model:
model = models_convmae.__dict__[args.model](
- num_classes=args.nb_classes,
+ class_num=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool, )
num_layers = len(model.blocks3) + 1
else:
model = models_mae.__dict__[args.model](
- num_classes=args.nb_classes,
+ class_num=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool, )
num_layers = len(model.blocks) + 1
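Note that the --nb_classes command-line flag itself is untouched throughout these scripts; only the keyword it is forwarded under changes, so existing training commands keep working.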
4 changes: 2 additions & 2 deletions tasks/ssl/mae/main_linprobe.py
@@ -276,11 +276,11 @@ def main(args):

if 'convvit' in args.model:
model = models_convmae.__dict__[args.model](
- num_classes=args.nb_classes,
+ class_num=args.nb_classes,
global_pool=args.global_pool, )
else:
model = models_mae.__dict__[args.model](
- num_classes=args.nb_classes,
+ class_num=args.nb_classes,
global_pool=args.global_pool, )

if args.finetune and not args.eval:
