thinh-researcher commited on
Commit
695e3cb
·
1 Parent(s): e89a14c
Files changed (1) hide show
  1. run.py +3 -5
run.py CHANGED
@@ -5,15 +5,13 @@ from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassif
5
 
6
  ResnetConfig.register_for_auto_class()
7
  ResnetModel.register_for_auto_class("AutoModel")
8
- ResnetModelForImageClassification.register_for_auto_class(
9
- "AutoModelForImageClassification"
10
- )
11
 
12
  resnet50d_config = ResnetConfig(
13
  block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
14
  )
15
  resnet50d = ResnetModelForImageClassification(resnet50d_config)
16
  pretrained_model = timm.create_model("resnet50d", pretrained=True)
17
- resnet50d.model.load_state_dict(pretrained_model.state_dict())
18
 
19
- resnet50d.push_to_hub("custom-resnet50d")
 
5
 
6
  ResnetConfig.register_for_auto_class()
7
  ResnetModel.register_for_auto_class("AutoModel")
8
+ ResnetModelForImageClassification.register_for_auto_class("AutoModel")
 
 
9
 
10
  resnet50d_config = ResnetConfig(
11
  block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
12
  )
13
  resnet50d = ResnetModelForImageClassification(resnet50d_config)
14
  pretrained_model = timm.create_model("resnet50d", pretrained=True)
15
+ resnet50d.model.model.load_state_dict(pretrained_model.state_dict())
16
 
17
+ resnet50d.push_to_hub("RGBD-SOD/custom-resnet50d")