MindSpore code for the CVPR 2018 paper Learning to Compare: Relation Network for Few-Shot Learning (Few-Shot Learning part).
For the Zero-Shot Learning part, please visit here.
RelationNet consists of two modules, an Encoder and a Relation module. The former has four convolution layers; the latter has two convolution layers and two linear layers.
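As a rough orientation, the Encoder's spatial sizes on a 28×28 Omniglot image can be traced with standard convolution arithmetic. Note: the layer layout below (two conv+pool blocks followed by two padded conv blocks) is assumed from the paper's reference implementation and may differ slightly in this port.

```python
def conv2d_out(size, kernel=3, stride=1, padding=0):
    """Output spatial size of a square convolution (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool_out(size, kernel=2, stride=2):
    """Output spatial size of a square max-pooling layer."""
    return (size - kernel) // stride + 1

# Hypothetical Encoder trace for a 28x28 input, assuming the layout of the
# paper's reference implementation (not verified against relationnet.py).
s = 28
s = maxpool_out(conv2d_out(s, padding=0))  # block 1: 28 -> 26 -> 13
s = maxpool_out(conv2d_out(s, padding=0))  # block 2: 13 -> 11 -> 5
s = conv2d_out(s, padding=1)               # block 3: 5 -> 5
s = conv2d_out(s, padding=1)               # block 4: 5 -> 5
print(s)  # final feature-map side length: 5
```

The Relation module then scores concatenated support/query feature maps of this size.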
Note that you can run the scripts with the dataset used in the original paper or with datasets widely used in this domain. The following sections describe how to run the scripts with the dataset below.
- Dataset used: Omniglot
- Dataset size: 4.02 MB, 32,462 28×28 images in 1,622 classes
- Data format: .png files
The directory structure is as follows:

```
└─Data
    ├─miniImagenet
    └─omniglot_resized
        ├─Alphabet_of_the_Magi
        ├─Angelic
        ...
```
After installing MindSpore via the official website, you can start training and evaluation as follows:

```shell
# enter script dir, train RelationNet
sh run_train_ascend.sh
# enter script dir, evaluate RelationNet
sh run_eval_ascend.sh
```
```
├── cv
    ├── FSL
        ├── README.md                // descriptions about RelationNet
        ├── scripts
        │   ├── run_train_ascend.sh  // train on Ascend
        │   ├── run_eval_ascend.sh   // evaluate on Ascend
        ├── src
        │   ├── config.py            // parameter configuration
        │   ├── dataset.py           // creating dataset
        │   ├── lr_generator.py      // generate learning rate
        │   ├── relationnet.py       // RelationNet architecture
        │   ├── net_train.py         // train model
        ├── train.py                 // training script
        ├── eval.py                  // evaluation script
```
Major parameters in train.py and config.py are as follows:
--class_num: the number of classes sampled in one episode.
--sample_num_per_class: the number of query images we extract from each class.
--batch_num_per_class: the number of support images we extract from each class.
--data_path: the absolute path to the training and evaluation datasets.
--episode: total number of training episodes.
--test_episode: total number of testing episodes.
--learning_rate: learning rate.
--device_target: device where the code will be run. Optional values are "Ascend", "GPU", "CPU".
--save_dir: the absolute path to the checkpoint file saved after training.
--cloud: whether to run on the cloud.
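The flags above could be wired up roughly like this (a hypothetical sketch; the defaults are illustrative, not necessarily the values in `src/config.py`):

```python
import argparse

def build_parser():
    """Argument parser mirroring the options listed above.
    Default values are illustrative only."""
    parser = argparse.ArgumentParser(description="RelationNet few-shot training")
    parser.add_argument("--class_num", type=int, default=5,
                        help="number of classes sampled per episode")
    parser.add_argument("--sample_num_per_class", type=int, default=1,
                        help="query images drawn from each class")
    parser.add_argument("--batch_num_per_class", type=int, default=19,
                        help="support images drawn from each class")
    parser.add_argument("--data_path", type=str, default="Data",
                        help="absolute path to the train/eval datasets")
    parser.add_argument("--episode", type=int, default=1000000,
                        help="total training episodes")
    parser.add_argument("--test_episode", type=int, default=1000,
                        help="total testing episodes")
    parser.add_argument("--learning_rate", type=float, default=0.001)
    parser.add_argument("--device_target", type=str, default="Ascend",
                        choices=["Ascend", "GPU", "CPU"])
    parser.add_argument("--save_dir", type=str, default="ckpt",
                        help="directory for the saved checkpoint")
    parser.add_argument("--cloud", action="store_true",
                        help="whether the job runs on the cloud")
    return parser

args = build_parser().parse_args(["--class_num", "5", "--data_path", "Data"])
```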
```shell
python train.py --data_path Data --ckpt_path ckpt > log.txt 2>&1 &
# or enter script dir, and run the script
sh run_train_ascend.sh
```
After training, you will obtain loss values like the following:
```
# grep train.log
...
init data folders
init neural networks
init optim,loss
init loss function and grads
==========Training==========
-----Episode 100/1000000-----
Episode: 100 Train, Loss(MSE): 0.16057138
-----Episode 200/1000000-----
Episode: 200 Train, Loss(MSE): 0.16390544
-----Episode 300/1000000-----
Episode: 300 Train, Loss(MSE): 0.1247341
...
```
The model checkpoint will be saved in the current directory.
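To track convergence, log lines of the form shown above can be pulled out with a small helper (a hypothetical convenience script, not part of the repository):

```python
import re

# Matches lines like: "Episode: 100 Train, Loss(MSE): 0.16057138"
LOSS_RE = re.compile(r"Episode:\s*(\d+)\s*Train, Loss\(MSE\):\s*([\d.]+)")

def parse_losses(log_text):
    """Extract (episode, loss) pairs from training log text."""
    return [(int(ep), float(loss)) for ep, loss in LOSS_RE.findall(log_text)]

sample = """-----Episode 100/1000000-----
Episode: 100 Train, Loss(MSE): 0.16057138
-----Episode 200/1000000-----
Episode: 200 Train, Loss(MSE): 0.16390544"""
print(parse_losses(sample))  # [(100, 0.16057138), (200, 0.16390544)]
```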
Before running the command below, please check the checkpoint path used for evaluation.
```shell
python eval.py --data_path Data > log.txt 2>&1 &
# or enter script dir, and run the script
sh run_eval_ascend.sh
```
You can view the results in the file "log.txt". The accuracy on the test dataset will be as follows:
```
# grep "Accuracy: " log.txt
'Accuracy': 0.9842
```
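For reference, per-episode few-shot accuracy is simply the fraction of query images whose highest relation score points at the correct class. A minimal sketch (names and score values are illustrative, not taken from eval.py):

```python
def episode_accuracy(relation_scores, labels):
    """relation_scores: one list of per-class relation scores per query image.
    labels: the index of the true class for each query image.
    Accuracy = fraction of queries whose argmax score is the true class."""
    correct = sum(
        1 for scores, label in zip(relation_scores, labels)
        if max(range(len(scores)), key=scores.__getitem__) == label
    )
    return correct / len(labels)

# Two 5-way queries: the first is predicted correctly (class 2), the second is not.
scores = [[0.1, 0.2, 0.9, 0.0, 0.1],
          [0.7, 0.1, 0.1, 0.0, 0.1]]
print(episode_accuracy(scores, [2, 1]))  # 0.5
```

The reported test accuracy is this quantity averaged over all test episodes.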
| Parameters | RelationNet |
| --- | --- |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; Memory 755 GB |
| Uploaded Date | 03/26/2021 (month/day/year) |
| MindSpore Version | 1.1.1 |
| Dataset | Omniglot |
| Training Parameters | episode=1000000, class_num=5, lr=0.001 |
| Optimizer | Adam |
| Loss Function | MSE |
| Outputs | Accuracy |
| Loss | 0.002 |
| Speed | 6 s/episode |
| Total Time | 16 h 28 m (single device) |
| Checkpoint for Fine-tuning | 875 KB (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/FSL |
In dataset.py, we set the random seed inside the omniglot_character_folders function. In net_train.py, we use random.choice inside the train function.
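Fixing the seed before any sampling call is what makes the class split and episode composition reproducible across runs; a minimal illustration (function and folder names are hypothetical):

```python
import random

def sample_episode_classes(class_folders, class_num, seed=None):
    """Sample class_num classes for one episode. With a fixed seed the
    same classes come back every run, which is the effect of seeding in
    omniglot_character_folders for the train/test split."""
    rng = random.Random(seed)  # local RNG so the global state is untouched
    return rng.sample(class_folders, class_num)

folders = [f"class_{i:02d}" for i in range(20)]
# Same seed -> identical episode composition.
assert sample_episode_classes(folders, 5, seed=0) == sample_episode_classes(folders, 5, seed=0)
```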
Please check the official homepage.
Learning to Compare: Relation Network for Few-shot Learning by MindSpore