"""
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

Integration tests for BLIP2 models.
"""
-
import pytest
import torch
from lavis.models import load_model, load_model_and_preprocess
from PIL import Image

# Prefer the GPU when one is present; these checkpoints are large and
# effectively require CUDA, but the code still runs (slowly) on CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

# Shared sample image for every test. Conversion to RGB guarantees a
# 3-channel input regardless of how the PNG is stored on disk.
_merlion = Image.open("docs/_static/merlion.png")
raw_image = _merlion.convert("RGB")
-
-
class TestBlip2:
    """Integration tests for BLIP-2 captioning checkpoints.

    Each test loads one pretrained checkpoint, captions the shared sample
    image (module-level ``raw_image`` on ``device``), and verifies both the
    exact greedy caption and that ``num_captions=3`` returns three captions.
    """

    def _caption_and_check(self, name, model_type, expected_caption):
        """Load a model and run the common caption assertions.

        Args:
            name: registry name passed to ``load_model_and_preprocess``
                (e.g. ``"blip2_opt"`` or ``"blip2_t5"``).
            model_type: checkpoint variant within that registry name.
            expected_caption: exact caption string the default (greedy)
                decoding is expected to produce for the sample image.
        """
        model, vis_processors, _ = load_model_and_preprocess(
            name=name, model_type=model_type, is_eval=True, device=device
        )

        # vis_processors stores image transforms for "train" and "eval"
        # (validation / testing / inference); unsqueeze adds the batch dim.
        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

        # Default decoding is deterministic, so the caption text is stable
        # enough to pin to an exact expected string.
        caption = model.generate({"image": image})
        assert caption == [expected_caption]

        # With multiple captions only the count is stable, not the text.
        captions = model.generate({"image": image}, num_captions=3)
        assert len(captions) == 3

    def test_blip2_opt2p7b(self):
        # BLIP2-OPT-2.7b caption model, without finetuning on coco.
        self._caption_and_check(
            "blip2_opt",
            "pretrain_opt2.7b",
            "the merlion fountain in singapore",
        )

    def test_blip2_opt2p7b_coco(self):
        # BLIP2-OPT-2.7b caption model, finetuned on coco.
        self._caption_and_check(
            "blip2_opt",
            "caption_coco_opt2.7b",
            "a statue of a mermaid spraying water into the air",
        )

    def test_blip2_opt6p7b(self):
        # BLIP2-OPT-6.7b caption model, without finetuning on coco.
        # (Original comment said 2.7b — copy/paste slip, fixed.)
        self._caption_and_check(
            "blip2_opt",
            "pretrain_opt6.7b",
            "a statue of a merlion in front of a water fountain",
        )

    def test_blip2_opt6p7b_coco(self):
        # BLIP2-OPT-6.7b caption model, finetuned on coco.
        self._caption_and_check(
            "blip2_opt",
            "caption_coco_opt6.7b",
            "a large fountain spraying water into the air",
        )

    def test_blip2_flant5xl(self):
        # BLIP2-FLAN-T5-XL caption model.
        self._caption_and_check(
            "blip2_t5",
            "pretrain_flant5xl",
            "marina bay sands, singapore",
        )

    def test_blip2_flant5xxl(self):
        # BLIP2-FLAN-T5-XXL caption model.
        self._caption_and_check(
            "blip2_t5",
            "pretrain_flant5xxl",
            "the merlion statue in singapore",
        )
|