MindSpore code for the NeurIPS 2017 paper: Prototypical Networks for Few-shot Learning.
ProtoNet consists of two parts, named Encoder and Relation: the former has 4 convolution layers, the latter has 2 convolution layers and 2 linear layers.
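As a reference for the Encoder half, here is a minimal MindSpore sketch of a 4-convolution-layer embedding network. The channel width (64) and the exact block layout are assumptions for illustration, not the code in src/protonet.py:

```python
# A minimal sketch of the 4-conv-layer encoder described above.
# hidden=64 and the conv->BN->ReLU->pool block layout are assumptions.
import mindspore.nn as nn

def conv_block(in_channels, out_channels):
    # conv -> batch norm -> ReLU -> 2x2 max pool, the standard few-shot block
    return nn.SequentialCell([
        nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='same'),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ])

class Encoder(nn.Cell):
    """Maps a 1x28x28 Omniglot image to a 64-dimensional embedding."""
    def __init__(self, hidden=64):
        super().__init__()
        self.net = nn.SequentialCell([
            conv_block(1, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
            conv_block(hidden, hidden),
        ])
        self.flatten = nn.Flatten()

    def construct(self, x):
        # four 2x2 pools take 28x28 down to 1x1, so the output is hidden-dim
        return self.flatten(self.net(x))
```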
Note that you can run the scripts with the dataset mentioned in the original paper, or with other datasets widely used in this domain. The following sections describe how to run the scripts using the dataset below.
- Dataset used: Omniglot
- Dataset size: 4.02 MB, 32,462 images of size 28×28 in 1,622 classes
- Data format: .png files
The directory structure is as follows:

```text
└─Data
    ├─raw
    ├─splits
    │  └─vinyals
    │     ├─test.txt
    │     ├─train.txt
    │     ├─val.txt
    │     └─trainval.txt
    └─data
       ├─Alphabet_of_the_Magi
       ├─Angelic
```
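For illustration, a minimal sketch of how the vinyals split files might be consumed. The one-entry-per-line format and the read_split helper are assumptions, not the repo's actual loader in src/dataset.py:

```python
# A sketch of reading a vinyals split file; assumes one class entry per line.
import os

def read_split(data_root, split_name):
    """Return the class entries listed in Data/splits/vinyals/<split>.txt."""
    path = os.path.join(data_root, 'splits', 'vinyals', split_name + '.txt')
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]

train_classes = read_split('Data', 'train')
print(len(train_classes), 'training class entries')
```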
After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# enter script dir, train ProtoNet
sh run_train_ascend.sh
# enter script dir, evaluate ProtoNet
sh run_eval_ascend.sh
```
```text
├── cv
    ├── ProtoNet
        ├── README.md                   // descriptions about ProtoNet
        ├── scripts
        │   ├── run_train_ascend.sh     // train in Ascend
        │   ├── run_eval_ascend.sh      // evaluate in Ascend
        ├── src
        │   ├── parser_util.py          // parameter configuration
        │   ├── dataset.py              // creating dataset
        │   ├── IterDatasetGenerator.py // generates the dataset
        │   ├── protonet.py             // ProtoNet architecture
        │   ├── PrototypicalLoss.py     // loss function (see the sketch below)
        ├── train.py                    // training script
        ├── eval.py                     // evaluation script
```
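PrototypicalLoss.py implements the episode loss from the paper: each class prototype is the mean of that class's support embeddings, and queries are classified by a softmax over negative squared Euclidean distances to the prototypes. Below is a framework-agnostic NumPy sketch of that computation; the function name and signature are illustrative, not the repo's actual API:

```python
# A NumPy sketch of the prototypical loss; names here are illustrative.
import numpy as np

def log_softmax(x):
    # numerically stable log-softmax over the last axis
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def prototypical_loss(support, support_labels, query, query_labels, n_classes):
    """support: (n_support, d) embeddings; query: (n_query, d) embeddings."""
    # class prototype = mean embedding of that class's support samples
    prototypes = np.stack([support[support_labels == c].mean(axis=0)
                           for c in range(n_classes)])         # (n_classes, d)
    # squared Euclidean distance from every query to every prototype
    d2 = ((query[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    log_p = log_softmax(-d2)                                   # (n_query, n_classes)
    loss = -log_p[np.arange(len(query_labels)), query_labels].mean()
    acc = (log_p.argmax(axis=-1) == query_labels).mean()
    return loss, acc

# usage: a synthetic 5-way episode with 5 support and 3 query samples per class
rng = np.random.default_rng(0)
support = rng.normal(size=(25, 64))
support_labels = np.repeat(np.arange(5), 5)
query = rng.normal(size=(15, 64))
query_labels = np.repeat(np.arange(5), 3)
print(prototypical_loss(support, support_labels, query, query_labels, 5))
```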
Major parameters in train.py and src/parser_util.py are as follows:

```text
--class_num               number of classes used in one step
--sample_num_per_class    number of query samples extracted from each class
--batch_num_per_class     number of support samples extracted from each class
--data_path               absolute path to the train and evaluation datasets
--episode                 total training epochs
--test_episode            total testing episodes
--learning_rate           learning rate
--device_target           device on which the code will run
--save_dir                absolute path to the checkpoint file saved after training
```
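A sketch of how these flags might be declared in src/parser_util.py. The defaults are assumptions: class_num=5, lr=0.001, and episode=500 are taken from the performance table below, the rest are illustrative placeholders:

```python
# A hypothetical sketch of the flag declarations; defaults are assumptions.
import argparse

def get_parser():
    parser = argparse.ArgumentParser(description='ProtoNet few-shot training')
    parser.add_argument('--class_num', type=int, default=5,
                        help='number of classes used in one step (N-way)')
    parser.add_argument('--sample_num_per_class', type=int, default=1,
                        help='query samples extracted from each class')
    parser.add_argument('--batch_num_per_class', type=int, default=5,
                        help='support samples extracted from each class')
    parser.add_argument('--data_path', type=str, default='Data',
                        help='absolute path to the train/eval datasets')
    parser.add_argument('--episode', type=int, default=500,
                        help='total training epochs')
    parser.add_argument('--test_episode', type=int, default=1000,
                        help='total testing episodes')
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--device_target', type=str, default='Ascend',
                        help='device on which the code will run')
    parser.add_argument('--save_dir', type=str, default='ckpt',
                        help='path for checkpoints saved after training')
    return parser
```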
```bash
python train.py --data_path Data --ckpt_path ckpt > log.txt 2>&1 &
# or enter script dir, and run the script
sh run_train_ascend.sh
```
The model checkpoint will be saved in the current directory.
Before running the command below, please check the checkpoint path used for evaluation.
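If you need to point evaluation at a specific checkpoint, the usual MindSpore loading pattern looks like the sketch below; the checkpoint path and the stand-in network are placeholders, not the repo's actual arguments:

```python
# A sketch of the standard MindSpore checkpoint-loading pattern;
# the path and the stand-in network are hypothetical placeholders.
import mindspore.nn as nn
from mindspore.train.serialization import load_checkpoint, load_param_into_net

net = nn.Dense(64, 5)                                 # stand-in for the real cell
param_dict = load_checkpoint('ckpt/protonet.ckpt')    # hypothetical path
load_param_into_net(net, param_dict)                  # restore trained weights
```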
```bash
python eval.py --data_path Data > log.txt 2>&1 &
# or enter script dir, and run the script
sh run_eval_ascend.sh
```
The evaluation results are printed as follows:

```text
Test Acc: 0.9954400658607483  Loss: 0.02102319709956646
```
| Parameters | ProtoNet |
| --- | --- |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Uploaded Date | 03/26/2021 (month/day/year) |
| MindSpore Version | 1.1.1 |
| Dataset | Omniglot |
| Training Parameters | episode=500, class_num=5, lr=0.001 |
| Optimizer | Adam |
| Loss Function | PrototypicalLoss |
| Outputs | Accuracy |
| Loss | 0.002 |
| Speed | 215 ms/step |
| Total Time | 3h 23m (8p) |
| Checkpoint for Fine-tuning | 440 KB (.ckpt file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/pro1 |
Please check the official homepage.