|
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import json
- import numpy as np
-
- import time
- import os
- from six.moves import cPickle
-
- import opts
- import models
- from dataloader import *
- from dataloaderraw import *
- import eval_utils
- import argparse
- import misc.utils as utils
- import torch
- import logging
- from models import AoAModel
-
-
-
class ImageCaptionModel():
    """Inference wrapper around a trained AoA (Attention-on-Attention) captioner.

    Loads the pickled training infos (vocab + options) and model weights once
    in __init__; __call__ then decodes captions from precomputed image features.
    """

    def __init__(self,
                 infos_path="/ImageCaptioning/ImageCaptioning/log/log_aoanet_rl/infos_aoanet.pkl",
                 model_path="/ImageCaptioning/ImageCaptioning/log/log_aoanet_rl/model.pth"):
        """
        infos_path: pickled training infos (vocabulary + training options),
                    saved at log/log_aoanet_rl/infos_aoanet.pkl
        model_path: trained model weights, saved at log/log_aoanet_rl/model.pth
        """
        logging.info("Loading infos from {}".format(infos_path))
        # BUG FIX: original logged infos_path here instead of model_path.
        logging.info("Loading model from {}".format(model_path))
        with open(infos_path, 'rb') as f:
            self.infos = utils.pickle_load(f)

        # Build evaluation options with their defaults, then overlay the
        # options saved at training time so the model is reconstructed
        # consistently with how it was trained.
        parser = argparse.ArgumentParser()
        opts.add_eval_options(parser)
        opt = parser.parse_args([])

        # For these keys a (non-empty) eval-time value wins; otherwise fall
        # back to the value stored in the training infos.
        replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir',
                   'input_label_h5', 'input_json', 'batch_size', 'id']
        ignore = ['start_from']
        for k in vars(self.infos['opt']).keys():
            if k in replace:
                setattr(opt, k, getattr(opt, k) or getattr(self.infos['opt'], k, ''))
            elif k not in ignore:
                if k not in vars(opt):
                    # Copy over options from the trained model.
                    vars(opt).update({k: vars(self.infos['opt'])[k]})

        self.vocab = self.infos['vocab']
        opt.vocab = self.vocab
        self.model = AoAModel(opt)
        # The model keeps what it needs from opt.vocab during construction;
        # drop the reference afterwards (mirrors the original code).
        del opt.vocab
        self.model.load_state_dict(torch.load(model_path))
        self.crit = utils.LanguageModelCriterion()  # evaluation criterion

    def __call__(self, fc_feats, att_feats, att_masks, mode="sample"):
        """
        Decode captions for a batch of preprocessed image features.

        fc_feats:  [batch_size, fc_feat_dim] global (fc) image features
        att_feats: [batch_size, num_feats, att_feat_dim] region features
        att_masks: [batch_size, num_feats] mask used by multi-head attention
        mode:      decoding mode forwarded to the model (default "sample")

        Returns a list of decoded caption strings.
        """
        self.model.cuda()
        self.model.eval()

        with torch.no_grad():
            # BUG FIX: original hard-coded mode='sample', silently ignoring
            # the `mode` parameter; forward it (the default keeps the old
            # behavior for existing callers).
            seq = self.model(fc_feats, att_feats, att_masks, mode=mode)[0].data
            sents = utils.decode_sequence(self.vocab, seq)

        return sents
|