Add File
This commit is contained in:
167
backend/apps/ai_model/openai/llm.py
Normal file
167
backend/apps/ai_model/openai/llm.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from typing import Dict, Optional, Any, Iterator, cast, Mapping
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import BaseMessage, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
|
||||
SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
|
||||
from langchain_core.outputs import ChatGenerationChunk
|
||||
from langchain_core.runnables import RunnableConfig, ensure_config
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import _create_usage_metadata
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
id_ = _dict.get("id")
|
||||
role = cast(str, _dict.get("role"))
|
||||
content = cast(str, _dict.get("content") or "")
|
||||
additional_kwargs: dict = {}
|
||||
if 'reasoning_content' in _dict:
|
||||
additional_kwargs['reasoning_content'] = _dict.get('reasoning_content')
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
tool_call_chunks = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
tool_call_chunk(
|
||||
name=rtc["function"].get("name"),
|
||||
args=rtc["function"].get("arguments"),
|
||||
id=rtc.get("id"),
|
||||
index=rtc["index"],
|
||||
)
|
||||
for rtc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content, id=id_)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=id_,
|
||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
||||
)
|
||||
elif role in ("system", "developer") or default_class == SystemMessageChunk:
|
||||
if role == "developer":
|
||||
additional_kwargs = {"__openai_role__": "developer"}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
return SystemMessageChunk(
|
||||
content=content, id=id_, additional_kwargs=additional_kwargs
|
||||
)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
|
||||
elif role == "tool" or default_class == ToolMessageChunk:
|
||||
return ToolMessageChunk(
|
||||
content=content, tool_call_id=_dict["tool_call_id"], id=id_
|
||||
)
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role, id=id_)
|
||||
else:
|
||||
return default_class(content=content, id=id_)
|
||||
|
||||
|
||||
class BaseChatOpenAI(ChatOpenAI):
|
||||
usage_metadata: dict = {}
|
||||
|
||||
# custom_get_token_ids = custom_get_token_ids
|
||||
|
||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||
return self.usage_metadata
|
||||
|
||||
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
|
||||
kwargs['stream_usage'] = True
|
||||
for chunk in super()._stream(*args, **kwargs):
|
||||
if chunk.message.usage_metadata is not None:
|
||||
self.usage_metadata = chunk.message.usage_metadata
|
||||
yield chunk
|
||||
|
||||
def _convert_chunk_to_generation_chunk(
|
||||
self,
|
||||
chunk: dict,
|
||||
default_chunk_class: type,
|
||||
base_generation_info: Optional[dict],
|
||||
) -> Optional[ChatGenerationChunk]:
|
||||
if chunk.get("type") == "content.delta": # from beta.chat.completions.stream
|
||||
return None
|
||||
token_usage = chunk.get("usage")
|
||||
choices = (
|
||||
chunk.get("choices", [])
|
||||
# from beta.chat.completions.stream
|
||||
or chunk.get("chunk", {}).get("choices", [])
|
||||
)
|
||||
|
||||
usage_metadata: Optional[UsageMetadata] = (
|
||||
_create_usage_metadata(token_usage) if token_usage and token_usage.get("prompt_tokens") else None
|
||||
)
|
||||
if len(choices) == 0:
|
||||
# logprobs is implicitly None
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
message=default_chunk_class(content="", usage_metadata=usage_metadata)
|
||||
)
|
||||
return generation_chunk
|
||||
|
||||
choice = choices[0]
|
||||
if choice["delta"] is None:
|
||||
return None
|
||||
|
||||
message_chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"], default_chunk_class
|
||||
)
|
||||
generation_info = {**base_generation_info} if base_generation_info else {}
|
||||
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
if model_name := chunk.get("model"):
|
||||
generation_info["model_name"] = model_name
|
||||
if system_fingerprint := chunk.get("system_fingerprint"):
|
||||
generation_info["system_fingerprint"] = system_fingerprint
|
||||
|
||||
logprobs = choice.get("logprobs")
|
||||
if logprobs:
|
||||
generation_info["logprobs"] = logprobs
|
||||
|
||||
if usage_metadata and isinstance(message_chunk, AIMessageChunk):
|
||||
message_chunk.usage_metadata = usage_metadata
|
||||
|
||||
generation_chunk = ChatGenerationChunk(
|
||||
message=message_chunk, generation_info=generation_info or None
|
||||
)
|
||||
return generation_chunk
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
config = ensure_config(config)
|
||||
chat_result = cast(
|
||||
"ChatGeneration",
|
||||
self.generate_prompt(
|
||||
[self._convert_input(input)],
|
||||
stop=stop,
|
||||
callbacks=config.get("callbacks"),
|
||||
tags=config.get("tags"),
|
||||
metadata=config.get("metadata"),
|
||||
run_name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
**kwargs,
|
||||
).generations[0][0],
|
||||
|
||||
).message
|
||||
|
||||
self.usage_metadata = chat_result.response_metadata[
|
||||
'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
|
||||
return chat_result
|
||||
Reference in New Issue
Block a user