|
- import re
- from typing import Optional, List
- from langchain.llms.base import LLM
- from langchain.llms.utils import enforce_stop_tokens
- from pcl_pangu.online import Infer
- from pcl_pangu.context import set_context
- from pcl_pangu.model import alpha, evolution, mPangu
-
-
class PclPanguMaas(LLM):
    """LangChain LLM wrapper for the PCL-Pangu online (model-as-a-service) API.

    Forwards prompts to the remote ``Infer.generate`` endpoint and returns the
    generated text.
    """

    model: str = ""    # remote model identifier; "chat" models take a list-wrapped prompt
    api_key: str = ""  # credential forwarded to Infer.generate

    def __init__(self, model: str, api_key: str):
        """Store the target model name and the API key for later calls."""
        super().__init__()
        self.model = model
        self.api_key = api_key

    @property
    def _llm_type(self) -> str:
        """Identify this LLM by its underlying model name."""
        return self.model

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send *prompt* to the remote service and return the generated text.

        Args:
            prompt: raw prompt text; runs of blank lines are collapsed.
            stop: optional stop tokens; output is truncated at the first one.

        Returns:
            The generated text, or ``None`` when the response carries no
            ``results``/``generate_text`` entry.
        """
        # Collapse one-or-more consecutive newlines to a single newline.
        # (raw string; r"\n+" is equivalent to the old "\n\n*")
        prompt = re.sub(r"\n+", "\n", prompt)
        # Chat-style models expect a list of messages; others take a plain string.
        payload = [prompt] if "chat" in self.model else prompt
        response = Infer.generate(self.model, payload, self.api_key)
        ans = response.get("results", {}).get("generate_text", None)
        # Guard: enforce_stop_tokens would raise on a None answer, so only
        # apply stop-token truncation when we actually got text back.
        if ans is not None and stop is not None:
            ans = enforce_stop_tokens(ans, stop)
        return ans
-
-
class PclPanguLocal(LLM):
    """LangChain LLM wrapper running a PCL-Pangu model locally.

    Selects one of the pcl_pangu model families (alpha / evolution / mPangu)
    and builds a backend-specific config (ONNX, MindSpore NPU, or PyTorch GPU)
    used for local inference.
    """

    model_type: str = ""   # one of "alpha", "evolution", "mPangu"
    model: object = None   # the selected pcl_pangu model module
    config: object = None  # backend-specific model configuration

    def __init__(self, model_type: str, model_name: str, model_path: str, backend: str):
        """Initialize the local model and its backend configuration.

        Args:
            model_type: model family — "alpha", "evolution", or "mPangu".
            model_name: concrete model variant passed to the config builder.
            model_path: checkpoint/weights path to load.
            backend: runtime backend — contains "onnx", or equals
                "mindspore" (NPU) or "pytorch" (GPU).

        Raises:
            ValueError: if ``model_type`` or ``backend`` is not recognized.
        """
        super().__init__()
        self.model_type = model_type

        # Configure the pcl_pangu runtime for the chosen backend first,
        # matching the original initialization order.
        set_context(backend)

        # Dispatch table instead of an if/elif chain of string comparisons.
        families = {"alpha": alpha, "evolution": evolution, "mPangu": mPangu}
        if model_type not in families:
            # ValueError is more precise than bare Exception and remains
            # backward-compatible for callers catching Exception.
            raise ValueError(f"Error model type: {model_type!r}")
        self.model = families[model_type]

        if "onnx" in backend:
            self.config = self.model.model_config_onnx(model=model_name, load=model_path)
        elif backend == "mindspore":
            self.config = self.model.model_config_npu(model=model_name, load=model_path)
        elif backend == "pytorch":
            self.config = self.model.model_config_gpu(model=model_name, load=model_path)
        else:
            raise ValueError(f"Error backend: {backend!r}")

    @property
    def _llm_type(self) -> str:
        """Identify this LLM by its model family."""
        return self.model_type

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run local inference on *prompt* and return the generated text.

        Args:
            prompt: raw prompt text; runs of blank lines are collapsed.
            stop: optional stop tokens; output is truncated at the first one.

        Returns:
            The generated text, or ``None`` when the response carries no
            ``results``/``generate_text`` entry.
        """
        # Collapse one-or-more consecutive newlines to a single newline.
        # (raw string; r"\n+" is equivalent to the old "\n\n*")
        prompt = re.sub(r"\n+", "\n", prompt)
        response = self.model.inference(self.config, input=prompt)
        ans = response.get("results", {}).get("generate_text", None)
        # Guard: enforce_stop_tokens would raise on a None answer.
        if ans is not None and stop is not None:
            ans = enforce_stop_tokens(ans, stop)
        return ans
|