import torch
from transformers import AutoConfig, AutoModelForCausalLM
# Kept even though MultiModalityCausalLM is not referenced directly below:
# importing janus.models registers the Janus model/config classes with the
# transformers Auto* factories (same pattern as the upstream Janus demo).
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from fastapi import FastAPI
# Only the request schema and the prompt-translation helper are used here; the
# local `generate` defined below would otherwise shadow an imported one.
from generate import ImageRequest, translation

import numpy as np
import os
from c2net.context import prepare
from io import BytesIO
import base64

app = FastAPI()

c2net_context = prepare()

# Pick the first non-hidden, non-empty subdirectory under the platform's
# pretrained-model path as the model to serve.
files = [f for f in os.listdir(c2net_context.pretrain_model_path) if os.path.isdir(c2net_context.pretrain_model_path + f"/{f}")]
model_name = ""
for file in files:
    if (not file.startswith(".")) and len(os.listdir(c2net_context.pretrain_model_path + f"/{file}")) > 0:
        model_name = file
        break

print(f"Current model: {model_name}")

model_path = c2net_context.pretrain_model_path + f"/{model_name}"
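
# Illustrative sketch only (not used by the server): the same "first non-hidden,
# non-empty subdirectory" lookup as above, written with os.path.join and an
# explicit error when nothing suitable is found.
def _find_first_model(root: str) -> str:
    for name in sorted(os.listdir(root)):
        path = os.path.join(root, name)
        if not name.startswith(".") and os.path.isdir(path) and os.listdir(path):
            return name
    raise FileNotFoundError(f"no non-empty model directory found under {root}")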

config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
# bfloat16 on GPU, float16 on CPU (same scheme as the upstream Janus demo).
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # Build 2 * parallel_size sequences: even rows keep the full prompt
    # (conditional), odd rows are padded out (unconditional) for
    # classifier-free guidance.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                                  use_cache=True,
                                                  past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            # Classifier-free guidance: mix conditional and unconditional logits.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Duplicate each sampled token for its conditional/unconditional pair
            # and feed the image-token embedding back in for the next step.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the sampled image tokens back into image patches.
    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches
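
# Illustrative sketch only (not called by the server): how the loop above mixes
# conditional and unconditional logits. Even batch rows (0, 2, ...) carry the
# prompted sequences and odd rows (1, 3, ...) the padded ones, so [0::2] / [1::2]
# slicing separates the two halves before applying classifier-free guidance.
def _cfg_mixing_example(cfg_weight: float = 5.0, temperature: float = 1.0):
    logits = torch.randn(2 * 3, 8)        # toy batch: 3 image streams x (cond, uncond), vocab size 8
    logit_cond = logits[0::2, :]          # conditional rows
    logit_uncond = logits[1::2, :]        # unconditional rows
    mixed = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
    probs = torch.softmax(mixed / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)   # one sampled image token per stream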

def unpack(dec, width, height, parallel_size=5):
    # Decoded patches come back as (parallel_size, 3, H, W) floats in roughly
    # [-1, 1]; rescale to uint8 and move channels last for PIL.
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img
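
# Illustrative sketch only (not called by the server): the tensor layout around
# unpack(). A decode_code()-like tensor of shape (parallel_size, 3, H, W) in
# roughly [-1, 1] comes back as uint8 images shaped (parallel_size, H, W, 3).
def _unpack_shape_example(parallel_size: int = 2):
    fake_dec = torch.rand(parallel_size, 3, 384, 384) * 2 - 1   # stand-in for decoded patches
    imgs = unpack(fake_dec, 384, 384, parallel_size=parallel_size)
    assert imgs.shape == (parallel_size, 384, 384, 3) and imgs.dtype == np.uint8
    return imgs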

# Route prefix provided by the notebook platform; default to '' so the route is
# still valid when the variable is not set.
base_url = os.getenv('OCTOPUS_NOTEBOOK_BASE_URL', '')

# Register the route as the outermost decorator so FastAPI picks up the
# inference_mode-wrapped handler.
@app.post(f"{base_url}/text2image")
@torch.inference_mode()
# @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(image_request: ImageRequest):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    if image_request.seed is not None:
        torch.manual_seed(image_request.seed)
        torch.cuda.manual_seed(image_request.seed)
        np.random.seed(image_request.seed)
    # Generation runs at a fixed 384x384; the requested size is applied by
    # resizing afterwards.
    width = 384
    height = 384
    parallel_size = image_request.num_images_per_prompt

    with torch.no_grad():
        messages = [{'role': '<|User|>', 'content': translation(image_request.prompt)},
                    {'role': '<|Assistant|>', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                           sft_format=vl_chat_processor.sft_format,
                                                                           system_prompt='')
        text = text + vl_chat_processor.image_start_tag

        input_ids = torch.LongTensor(tokenizer.encode(text))
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=image_request.guidance_scale,
                                   parallel_size=parallel_size,
                                   temperature=1.0)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16,
                        parallel_size=parallel_size)
        generated_images = [Image.fromarray(images[i]).resize((image_request.width, image_request.height), Image.LANCZOS)
                            for i in range(parallel_size)]

        # Encode each PNG as base64 so the response is plain JSON.
        image_base64_list = []
        for generated_image in generated_images:
            buffered = BytesIO()
            generated_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            image_base64_list.append(img_str)

        return {"image_base64_list": image_base64_list}
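
# Example client call, shown as a sketch under assumptions: the service listens on
# localhost:8888, OCTOPUS_NOTEBOOK_BASE_URL is unset (empty prefix), and ImageRequest
# accepts the fields referenced in generate_image above. Not executed by the server.
def _example_client(prompt: str = "a photo of a corgi", out_prefix: str = "janus_out"):
    import requests  # imported lazily so the server itself does not need it

    payload = {
        "prompt": prompt,
        "seed": 42,
        "guidance_scale": 5.0,
        "num_images_per_prompt": 2,
        "width": 768,
        "height": 768,
    }
    resp = requests.post("http://127.0.0.1:8888/text2image", json=payload, timeout=600)
    resp.raise_for_status()
    for i, img_b64 in enumerate(resp.json()["image_base64_list"]):
        with open(f"{out_prefix}_{i}.png", "wb") as f:
            f.write(base64.b64decode(img_b64))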
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8888)