|
- import os
- import argparse
- import torch
- import numpy as np
- from tqdm import tqdm
- from torch.utils.data import DataLoader
- from lavis.models.eva_vit import create_eva_vit_g
- from lavis.datasets.builders.medical_builder import CLEF2023CapBuilder
-
# Route every HuggingFace / torch cache directory to the shared user cache.
# NOTE(review): these are assigned *after* the library imports above — if any
# imported library reads its cache env var at import time, this override may
# come too late; confirm it actually takes effect.
_CACHE_DIR = '/userhome/.cache'
for _cache_key in ('HF_HOME', 'TORCH_HOME', 'TRANSFORMERS_CACHE', 'HUGGINGFACE_HUB_CACHE'):
    os.environ[_cache_key] = _CACHE_DIR
-
# Command-line interface: input resolution, an optional fine-tuned
# checkpoint, and the extraction batch size.
parser = argparse.ArgumentParser()
parser.add_argument('--img_size', type=int, default=224, choices=[224, 364])
parser.add_argument('--ckpt', type=str)  # default None => use pretrained weights as-is
parser.add_argument('--batch_size', type=int, default=32)
args = parser.parse_args()
-
print('Loading datasets ...')

# Build the CLEF-2023 caption dataset with BLIP image processors at the
# requested resolution (same size for the train and eval transforms).
builder = CLEF2023CapBuilder()
train_proc = {'name': 'blip_image_train', 'image_size': args.img_size}
eval_proc = {'name': 'blip_image_eval', 'image_size': args.img_size}
builder.config['vis_processor'] = {'train': train_proc, 'eval': eval_proc}
print(builder.config)

datasets = builder.build_datasets()

# Features land under <vis_root>/features/<img_size>[_ft]; the "_ft" suffix
# marks embeddings produced by a fine-tuned (rather than stock) backbone.
suffix = '_ft' if args.ckpt else ''
save_path = os.path.join(datasets['train'].vis_root, 'features', f'{args.img_size}{suffix}')
os.makedirs(save_path, exist_ok=True)
print('Features will be saved to', save_path)
-
print('Loading vit ...')
# EVA ViT-g backbone in fp16, created by LAVIS at the requested resolution.
model = create_eva_vit_g(
    img_size=args.img_size,
    precision='fp16',
)

# Gate on truthiness (not `is not None`) so an empty --ckpt string behaves
# the same here as in the "_ft" save-path suffix logic above; previously an
# empty string skipped the suffix but still attempted a checkpoint load.
if args.ckpt:
    print('- load state_dict from', args.ckpt)
    now_state_dict = model.state_dict()
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True).
    state_dict = torch.load(args.ckpt, 'cpu')['model']
    # The fine-tuning checkpoint stores the backbone under a
    # 'visual_encoder.' prefix; strip it and keep only keys the ViT owns.
    to_load = {}
    for k, v in state_dict.items():
        new_k = k.replace('visual_encoder.', '')
        if new_k in now_state_dict:
            to_load[new_k] = v
    print(len(to_load))
    # strict=True doubles as a sanity check that every backbone parameter
    # was recovered from the checkpoint.
    model.load_state_dict(to_load, strict=True)
    # Persist the stripped backbone weights for later standalone reuse.
    torch.save(model.state_dict(), 'finetuned_eva_vit_g.pth')

model = model.eval().to('cuda')
-
print('Start extracting features')
# For every dataset split, run the frozen ViT over all images and write one
# .npy embedding file per image, named after its image_id.
for mode in tqdm(datasets.keys()):
    loader = DataLoader(datasets[mode], batch_size=args.batch_size)
    for batch in tqdm(loader):
        images = batch['image'].to('cuda')
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
            image_embeds = model(images)

        embeds_np = image_embeds.cpu().numpy()
        for image_id, embed in zip(batch['image_id'], embeds_np):
            np.save(os.path.join(save_path, f'{image_id}.npy'), embed)
|