|
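"""Text-to-image service on GCU: only a single ONNX model is selected."""
# Example request from a notebook: !curl -X POST http://0.0.0.0:8887/text2image -H "Content-Type: application/json" -d '{"prompt": "cute dragon creature", "negative_prompt": "nsfw"}'
# NOTE(review): this script starts uvicorn on port 8888 and prefixes the route with
# OCTOPUS_NOTEBOOK_BASE_URL — confirm the URL above matches the actual deployment.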
import base64
import os
import shutil

from fastapi import FastAPI
from fastapi import HTTPException
from PIL import Image
from pydantic import BaseModel
-
# ASGI application object; the HTTP routes of this module are registered on it.
app = FastAPI()

from c2net.context import prepare

# Initialise the container: import the dataset and the pretrained model into it.
c2net_context = prepare()

# URL prefix for the routes, read from the environment (None when the variable is unset).
base_url = os.environ.get('OCTOPUS_NOTEBOOK_BASE_URL')

import subprocess

# Directory of the ONNX model inside the pretrained-model area, and the
# 512x512 safetensor sub-directory that is handed to the inference script.
sdtraining_model_shuimo_onnx_path = f"{c2net_context.pretrain_model_path}/sdtraining_model_ShuiMo—onnx"
model_path = f"{sdtraining_model_shuimo_onnx_path}/safetensor/512x512/"
class Text2ImageRequest(BaseModel):
    # JSON body of a text-to-image request. Every field carries a default,
    # so a client may omit any of them.
    # (Plain comments are used instead of a docstring on purpose: pydantic
    # would copy a class docstring into the generated schema.)

    # Model directory named by the client.
    model_base: str = "/tmp/pretrainmodel/sdtraining_model_ShuiMo—onnx/safetensor/512x512/"

    # Directory that receives the generated images.
    gcu_results: str = "/tmp/code/results/text2image"

    # Number of images to generate for the prompt.
    num_images_per_prompt: int = 4

    # Positive and negative text prompts.
    prompt: str = "cute dragon creature"
    negative_prompt: str = "nsfw"

    # Value for the inference script's --gcu option
    # (presumably the device index — TODO confirm).
    gcu: int = 0
-
# Only directories strictly below this root may be wiped and used as the
# image output directory (the request default lives below it).
_RESULTS_ROOT = "/tmp/code/results"

# File extensions (lower-case) that are returned to the client.
_IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')


def _resolve_output_dir(requested: str) -> str:
    """Return the canonical form of *requested*, refusing paths outside the results root.

    Raises:
        ValueError: if the resolved path is the results root itself or lies
            outside of it.
    """
    root = os.path.realpath(_RESULTS_ROOT)
    target = os.path.realpath(requested)
    if target == root or os.path.commonpath([root, target]) != root:
        raise ValueError(
            f"gcu_results must be a directory below {_RESULTS_ROOT}: {requested!r}"
        )
    return target


# `base_url` is None when OCTOPUS_NOTEBOOK_BASE_URL is unset; fall back to an
# empty prefix so the route is "/text2image" instead of "None/text2image".
@app.post(f"{base_url or ''}/text2image")
async def generate_image(request: Text2ImageRequest):
    """Generate images for a prompt and return them base64-encoded.

    The output directory is emptied, the inference script is run as a child
    process, and every PNG/JPEG file found in the directory afterwards is
    returned.

    Args:
        request: Request body. ``request.model_base`` is ignored; the
            module-level ``model_path`` is always used as the model.

    Returns:
        ``{"image_base64_list": [...]}`` with one base64 string per image,
        ordered by file name.

    Raises:
        HTTPException: 400 if ``gcu_results`` points outside the results
            root; 500 if the inference script exits with a non-zero status.
    """
    # The output directory is deleted recursively below, so a client-supplied
    # path must never be trusted as-is.
    try:
        output_dir = _resolve_output_dir(request.gcu_results)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    command = [
        "python3",
        "-m",
        "stable_diffusion.examples.stable_diffusion.demo_text2image_topsinference",
        "--model",
        model_path,
        "--image_num",
        str(request.num_images_per_prompt),
        "--prompt",
        request.prompt,
        "--negative_prompt",
        request.negative_prompt,
        "--output",
        output_dir,
        "--gcu",
        str(request.gcu),
        "--model_type",
        "sd_v1_5",
        "--image_height",
        "512",
        "--image_width",
        "512",
        "--platform",
        "general",
        "--scheduler",
        "ddim",
        "--denoising_steps",
        "20",
    ]

    # Start from an empty directory so images left over from an earlier
    # request are never returned.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    # NOTE: subprocess.run blocks until the child exits; because this handler
    # is `async def`, the event loop is blocked for that time. This is left
    # unchanged on purpose: it also keeps concurrent requests from sharing
    # the output directory.
    completed = subprocess.run(command)
    if completed.returncode != 0:
        # Report the failure instead of answering 200 with an empty list.
        raise HTTPException(
            status_code=500,
            detail=f"image generation failed with exit code {completed.returncode}",
        )

    image_base64_list = []
    # os.listdir returns entries in arbitrary order; sort for a stable result.
    for file_name in sorted(os.listdir(output_dir)):
        if not file_name.lower().endswith(_IMAGE_SUFFIXES):
            continue
        with open(os.path.join(output_dir, file_name), "rb") as image_file:
            image_base64_list.append(base64.b64encode(image_file.read()).decode())

    return {"image_base64_list": image_base64_list}
-
if __name__ == "__main__":
    # uvicorn is imported here because it is needed only when this file is
    # executed as a script.
    import uvicorn

    # Listen on all interfaces, port 8888.
    server_options = {"host": "0.0.0.0", "port": 8888}
    uvicorn.run(app, **server_options)
|