2024. 6. 29. 23:08ㆍ관심있는 주제/LLM
LangGraph 가 나온 이유
LLM을 통해서 FLOW를 만들다 보면, 이전 LLM 결과에 대해서 다음 작업으로 넘길 때, LLM이 생성되는 결과에 의해 실패하는 경험들을 다들 해보셨을 것입니다.
개인적으로 분기 처리나 후처리 로직을 담는데, Output Parser에서 제어하려 했지만, 이러면 전반적인 구조나 결과를 보기가 쉽지 않았습니다.
LangGraph는 LangChain 생태계 내에서 이러한 문제를 직접 해결하기 위해 설계된 강력한 라이브러리입니다. 이 라이브러리는 여러 LLM 에이전트(또는 체인)를 구조화된 방식으로 정의, 조정 및 실행할 수 있는 프레임워크를 제공합니다.
LangGraph란 무엇인가요?
LangGraph는 LLM을 사용하여 상태를 유지하고 여러 에이전트를 포함한 애플리케이션을 쉽게 만들 수 있도록 도와줍니다. 이 도구는 LangChain의 기능을 확장하여, 복잡한 에이전트 런타임 개발에 필수적인 순환 그래프를 만들고 관리할 수 있는 기능을 추가합니다. LangGraph의 핵심 개념에는 그래프 구조, 상태 관리 및 조정이 포함됩니다.
Graph structure
LangGraph에서는 각 노드가 LLM 에이전트를 나타내고, edge는 이 에이전트들 간의 통신 채널입니다.
이 구조는 각 에이전트가 특정 작업을 수행하고 필요에 따라 다른 에이전트에 정보를 전달하는 명확하고 관리 가능한 워크플로를 허용합니다.
State management
LangGraph의 두드러진 특징 중 하나는 자동 상태 관리 기능입니다. 이 기능을 통해 여러 상호작용에 걸쳐 정보를 추적하고 유지할 수 있습니다. 에이전트가 작업을 수행함에 따라 상태가 동적으로 업데이트되어 시스템이 컨텍스트를 유지하고 새로운 입력에 적절히 반응할 수 있습니다.
이게 특히 좋은 것이 기존에 llm chain에서 하려고 하면 runnable 로 계속 감싸고 감싸는 구조를 만들었던 기억이 있는데, 이렇게 관리를 해준다는 게 너무 좋은 것 같습니다.
Coordination
LangGraph는 에이전트가 올바른 순서로 실행되고 필요한 정보가 원활하게 교환되도록 보장합니다. 이는 여러 에이전트가 협력하여 공동의 목표를 달성해야 하는 복잡한 애플리케이션에서 필수적입니다. 데이터 흐름과 작업 순서를 관리함으로써 LangGraph는 개발자가 에이전트 조정의 복잡한 세부 사항보다는 애플리케이션의 고수준 논리에 집중할 수 있게 합니다.
왜 써야 하는가?
Simplified development
LangGraph는 상태 관리 및 에이전트 조정과 관련된 복잡성을 추상화합니다. 따라서 개발자는 데이터 일관성과 올바른 실행 순서를 보장하는 기본 메커니즘에 대해 걱정하지 않고 워크플로와 로직을 정의할 수 있습니다. 이러한 간소화는 개발 속도를 높이고 오류 발생 가능성을 줄여줍니다. 정말 혁신적입니다!
아직 익숙하지 않지만, 익숙해지면 편리할 것 같은 구조인 것 같습니다.
Flexibility
LangGraph를 사용하면 개발자는 자신의 에이전트 로직과 통신 프로토콜을 정의할 수 있는 유연성을 갖게 됩니다. 이를 통해 특정 용도에 맞춘 맞춤형 애플리케이션을 쉽게 만들 수 있습니다. 다양한 유형의 사용자 요청을 처리할 수 있는 챗봇이 필요하든, 복잡한 작업을 수행하는 멀티 에이전트 시스템이 필요하든, LangGraph는 필요한 도구를 제공합니다. 창의력을 마음껏 발휘할 수 있습니다.
Scalability
LangGraph는 대규모 멀티 에이전트 애플리케이션의 실행을 지원하도록 설계되었습니다. 이 견고한 아키텍처는 대량의 상호작용과 복잡한 워크플로를 처리할 수 있어, 필요에 따라 확장 가능한 시스템을 개발할 수 있습니다. 이는 성능과 안정성이 중요한 엔터프라이즈 수준의 애플리케이션과 시나리오에 적합합니다.
Fault tolerance
LangGraph의 설계에서 신뢰성은 핵심 고려 사항입니다. 이 라이브러리는 오류를 우아하게 처리하는 메커니즘을 포함하고 있어, 개별 에이전트에 문제가 발생해도 애플리케이션이 계속 작동할 수 있습니다. 이러한 내결함성은 복잡한 멀티 에이전트 시스템의 안정성과 견고성을 유지하는 데 필수적입니다. 안심하고 사용할 수 있습니다.
아래에서는 2가지 예시로 해보고자 합니다.
langgraph라는 게 결국 workflow 간의 상태 관리를 하면서 flow를 더 쉽게 짜주는 프레임워크인 것 같습니다.
그래서 사용자가 꼭 langchain을 안써도 다른 workflow로 흐름을 제어할 것이 있다면, 해당 작업을 통해 매우 쉽게 만들 수 있을 것 같습니다.
그래서 1개는 간단하게 iris 데이터를 활용하여 분류 문제를 풀 때 특정 임계치까지 도달할 때까지 workflow가 돌게 하는 구조를 한번 만들어봤습니다.
다른 1개는 LLM 모델로 FLOW를 만드는 예시를 작성해보았습니다.
Building a Simple LangGraph Application (Train ML Model)
graph는 다음과 같이 구성해 봤습니다.
STEP | Action | Next Step |
Start | Initialize the workflow | retrieve_data |
retrieve_data | 데이터를 불러오고 전처리 | train_model |
train_model | Train the Decision Tree model 파라미터 변경된 걸로 재학습시키기 |
evaluate_model |
evaluate_model | Evaluate the model's F1 score (평가 score에 따라서 분기가 3개로 나감) |
visualize_performance (if F1 score > goal_threshold) |
adjust_parameters (if F1 score ≤ goal_threshold and iterations < 20) | ||
handle_stop_condition (if iterations ≥ 20) | ||
adjust_parameters | max_depth를 하나씩 증가시키 | train_model |
handle_stop_condition | Stop the workflow due to reaching the maximum iterations 정해놓은 run id 에서 결과를 찾지 못하면 종료 |
END |
visualize_performance | 모델 결과가 나오면 시각화 및 저장 | END |
END | End of the workflow | - |
한 가지 특이한 점은 recursion_limit의 1개의 실행 의미가 1개의 edge와 동일합니다
그래서 내가 만약 iteration을 20번 돌고 싶다면 해당 flow가 존재하는 edge의 개수만큼을 다 설정해야 합니다.
즉 recursion_limit을 일반적으로 생각하는 단위가 아니라 edge 단위로 계산해줘야 합니다.
한 가지 또 특이한 점은 만약 일반적으로 langgraph를 쓰고 싶다면 memorysaver를 사용해야 하는데, 그때 json 파일 형식만 가능하기 때문에 json으로 바꾸지 않는 경우 사용하지 못합니다.
config = RunnableConfig(recursion_limit=2 + 3 * 20 + 3, configurable={"thread_id": "THREAD_ID"})
GraphState 정의
아래는 다음과 같이 내가 workflow를 짤 동안 다른 node에 전달할 것과 기록하고 싶은 것들을 지정하는 부분이다.
이를 통해 현재 결과와 기록을 할 수 있습니다.
어떻게 보면 사용자는 Node들을 각자 개발하고 해당 개발 flow를 디자인해 주면 매우 쉽게 설계할 수 있을 것 같다는 생각이 들었습니다.
data | Optional[pd.DataFrame] | 훈련 및 테스트 모델에 사용될 데이터셋을 저장합니다. 전처리된 특성 열을 포함한 pandas DataFrame입니다. |
target | Optional[pd.Series] | 데이터셋의 타겟(레이블)을 저장합니다. pandas Series 형태입니다. |
params | Optional[dict] | 모델 훈련에 사용될 하이퍼파라미터를 저장합니다. 딕셔너리 형태로 키와 값으로 구성됩니다. |
model | Optional[DecisionTreeClassifier] | 훈련된 결정 트리 모델을 저장합니다. DecisionTreeClassifier 객체입니다. |
f1_score | Optional[float] | 모델의 평가 지표인 F1 점수를 저장합니다. float 값입니다. |
response | Optional[str] | 현재 상태에 대한 응답 메시지를 저장합니다. 주로 디버깅이나 로깅 목적으로 사용됩니다. |
goal_threshold | Optional[float] | 모델이 도달해야 하는 목표 F1 점수를 저장합니다. float 값입니다. |
X_test | Optional[pd.DataFrame] | 테스트 데이터셋의 특성 데이터를 저장합니다. pandas DataFrame 형태입니다. |
y_test | Optional[pd.Series] | 테스트 데이터셋의 타겟(레이블) 데이터를 저장합니다. pandas Series 형태입니다. |
history | Optional[list] | 각 반복마다 모델의 파라미터와 F1 점수를 저장한 이력 리스트입니다. |
iterations | Optional[int] | 모델 훈련 반복 횟수를 저장합니다. int 값입니다. |
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from typing import TypedDict, Optional
import matplotlib.pyplot as plt
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.sqlite import SqliteSaver
# Define the GraphState class
class GraphState(TypedDict):
data: Optional[pd.DataFrame] = None
target: Optional[pd.Series] = None
params: Optional[dict] = None
model: Optional[DecisionTreeClassifier] = None
f1_score: Optional[float] = None
response: Optional[str] = None
goal_threshold: Optional[float] = None
X_test: Optional[pd.DataFrame] = None
y_test: Optional[pd.Series] = None
history: Optional[list] = None
iterations: Optional[int] = None
# Load and preprocess the data
def retrieve_data(state: GraphState) -> GraphState:
url = 'https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv'
df = pd.read_csv(url)
# Preprocessing steps
df.drop(columns=['Name', 'Ticket', 'Cabin'], inplace=True)
df['Age'].fillna(df['Age'].median(), inplace=True)
df['Embarked'].fillna(df['Embarked'].mode()[0], inplace=True)
# Encode categorical variables
le = LabelEncoder()
df['Sex'] = le.fit_transform(df['Sex'])
df['Embarked'] = le.fit_transform(df['Embarked'])
target = df['Survived']
df.drop(columns=['Survived'], inplace=True)
return {**state, "data": df, "target": target, "params": {"max_depth": 2}, "history": [], "iterations": 0}
# Train the model
def train_model(state: GraphState) -> GraphState:
data = state["data"]
target = state["target"]
params = state["params"]
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.3, random_state=42)
model = DecisionTreeClassifier(**params)
model.fit(X_train, y_train)
return {**state, "model": model, "X_test": X_test, "y_test": y_test}
# Evaluate the model
def evaluate_model(state: GraphState) -> GraphState:
model = state["model"]
X_test = state["X_test"]
y_test = state["y_test"]
y_pred = model.predict(X_test)
f1_score_value = f1_score(y_test, y_pred, average='weighted')
history = state["history"]
history.append({"params": state["params"].copy(), "f1_score": f1_score_value})
return {**state, "f1_score": f1_score_value, "response": f"Model F1 score: {f1_score_value:.2f}", "history": history}
# Adjust parameters
def adjust_parameters(state: GraphState) -> GraphState:
params = state["params"]
params["max_depth"] += 1
iterations = state["iterations"] + 1
print(f"Adjusting parameters: Iteration {iterations}, Max Depth {params['max_depth']}")
return {**state, "params": params, "iterations": iterations}
# Check if the accuracy is sufficient
def is_sufficient_accuracy(state: GraphState) -> str:
goal_threshold = state["goal_threshold"]
if state["f1_score"] and state["f1_score"] > goal_threshold:
return "sufficient"
if state["iterations"] and state["iterations"] >= 20:
return "stop"
return "insufficient"
# Handle stop condition
def handle_stop_condition(state: GraphState) -> GraphState:
return {**state, "response": "Stopping the workflow due to reaching the maximum iterations."}
# Visualize the performance
def visualize_performance(state: GraphState) -> GraphState:
history = state["history"]
depths = [entry["params"]["max_depth"] for entry in history]
f1_scores = [entry["f1_score"] for entry in history]
plt.figure(figsize=(10, 6))
plt.plot(depths, f1_scores, marker='o', linestyle='-', color='b')
plt.title('Model Performance vs. Max Depth')
plt.xlabel('Max Depth')
plt.ylabel('F1 Score')
plt.grid(True)
plot_filename = 'model_performance.png'
plt.savefig(plot_filename)
plt.close()
return {**state, "response": f"Model training complete. Performance plot saved as '{plot_filename}'", "plot_filename": plot_filename}
# Initialize the state graph
workflow = StateGraph(GraphState)
# Add nodes to the workflow
workflow.add_node("retrieve_data", retrieve_data)
workflow.add_node("train_model", train_model)
workflow.add_node("evaluate_model", evaluate_model)
workflow.add_node("adjust_parameters", adjust_parameters)
workflow.add_node("handle_stop_condition", handle_stop_condition)
workflow.add_node("visualize_performance", visualize_performance)
# Add edges to the workflow
workflow.add_edge("retrieve_data", "train_model")
workflow.add_edge("train_model", "evaluate_model")
# Add conditional edges for retraining if accuracy is insufficient
workflow.add_conditional_edges(
"evaluate_model",
is_sufficient_accuracy,
{
"sufficient": "visualize_performance",
"insufficient": "adjust_parameters",
"stop": "handle_stop_condition"
}
)
workflow.add_edge("adjust_parameters", "train_model")
workflow.add_edge("handle_stop_condition", END)
workflow.add_edge("visualize_performance", END)
# Set entry point
workflow.set_entry_point("retrieve_data")
from langchain_core.runnables import RunnableConfig
# Increase the recursion limit
config = RunnableConfig(recursion_limit=2 + 3 * 20 + 3, configurable={"thread_id": "THREAD_ID"})
# Initialize SqliteSaver (사용 못함)
# memory = SqliteSaver.from_conn_string(":memory:")
# Compile the workflow with checkpointing
app = workflow.compile()
# Example input with goal_threshold
inputs = {
"goal_threshold": 0.95
}
result = app.invoke(inputs, config=config)
# Print the final result
print(result)
이런 식으로 결과를 저장하는 것을 할 수 있었습니다.
Building a Simple LangGraph Application (LLM Simple Chat)
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langchain_core.runnables import RunnableLambda
from langchain.llms.fake import FakeListLLM
from langchain.schema import (
AIMessage,
# HumanMessage,
# SystemMessage
)
from langchain_openai import ChatOpenAI
res = ["Action: python_repl_ast\nAction Input: print(2.2 + 2.22)", "Final Answer: 4.42"]
llm = FakeListLLM(responses=res) | RunnableLambda(lambda x: AIMessage(x))
# llm = ChatOpenAI()
def chatbot(state: State):
return {"messages": [llm.invoke(state["messages"])]}
class State(TypedDict):
# messages have the type "list".
# The add_messages function appends messages to the list, rather than overwriting them
messages: Annotated[list, add_messages]
graph_builder = StateGraph(State)
graph_builder.add_node("chatbot", chatbot)
graph_builder.set_entry_point("chatbot")
graph_builder.set_finish_point("chatbot")
graph = graph_builder.compile()
from IPython.display import Image, display
try:
display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
pass
위의 예시는 간단하게 FakeListLLM으로 llm을 만들고 chat이라는 node를 만들고, 시작과 끝을 지정하는 식으로 해봤습니다
# Run the chatbot
while True:
user_input = input("User: ")
if user_input.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
for event in graph.stream({"messages": [("user", user_input)]}):
for value in event.values():
print("Assistant:", value["messages"][-1].content)
Building a Simple LangGraph Application (LLM Simple Chat + Tool)
chat과 tool을 같이 사용해 봤습니다.
tool을 사용하는 것에 있어서는 이게 정답은 아닌 것 같고, ToolNode를 받아서 사용하는 것 같지만, 그것은 기존에 python tool 실행하는 거랑 하려다 보니 아직 잘 이해가 안돼서 좀 해봐야될 것 같습니다.
from typing import Annotated
from langgraph.graph import StateGraph , END
from langgraph.graph import END, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from typing import Annotated, Literal, TypedDict
from langchain_experimental.tools import PythonAstREPLTool
from langchain.llms.fake import FakeListLLM
from langchain_core.runnables import RunnableLambda
from langchain.schema import (
AIMessage,
HumanMessage,
# SystemMessage
)
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_experimental.tools import PythonAstREPLTool
from langchain.pydantic_v1 import BaseModel, Field
class SearchInput(BaseModel):
query: str = Field(description="should be a search query")
@tool("search-tool", args_schema=SearchInput, return_direct=True)
def search(query: str) -> str:
"""Look up things online."""
return "LangChain"
tools = [PythonAstREPLTool(), search]
tool_dict = {} # this is going to be required during tool execution
for tool in tools:
tool_dict[tool.name]= tool
# Define the function that determines whether to continue or not
def should_continue(state: MessagesState) -> Literal["tools", END]:
messages = state['messages']
last_message = messages[-1]
# If there is no function call, then we finish
if "function_call" not in last_message.additional_kwargs:
return END
# Otherwise if there is, we continue
else:
return "tools"
# Define the function that calls the model
class MyCustomNode:
def __init__(self, llm):
self.llm = llm
def __call__(self, state):
# Implement your custom logic here
# Access the state and perform actions
messages = state["messages"]
response = self.llm.invoke(messages)
return {"messages": [response]}
def tool_execution(state):
tool_name = state['messages'][-1].additional_kwargs['function_call']['name']
args = state['messages'][-1].additional_kwargs['function_call']['arguments']
# kwargs to args
result = tool_dict[tool_name](*args)
return {'messages': [AIMessage(content=result)]}
def responder(state):
return state
def route(state) :
print(state['messages'][-1] , len(state['messages'][-1].content))
if len(state['messages'][-1].content) != 0 :
return 'respond'
else :
return 'agent'
# llm = ChatAnthropic(model="claude-3-haiku-20240307")
res = ["print(2.2 + 2.22)", "print(4.42*5)"]
llm = FakeListLLM(responses=res) | RunnableLambda(lambda x: AIMessage(content="", additional_kwargs=dict(
function_call=dict(
name='python_repl_ast',
arguments=(x,)
)
)))
graph_builder = StateGraph(MessagesState)
graph_builder.add_node("agent", MyCustomNode(llm))
graph_builder.add_node("tools", tool_execution)
graph_builder.add_node("respond", responder)
graph_builder.set_entry_point("agent")
graph_builder.add_conditional_edges(
"agent",
should_continue,
{
"tools": "tools",
END: END # Correctly reference END here
}
)
graph_builder.add_conditional_edges("tools", route, {"respond":"respond", "agent":"agent"})
graph_builder.add_edge("respond", END)
# Connect to the SQLite database
memory = SqliteSaver.from_conn_string(":memory:")
# Compile the graph with the checkpointer
app = graph_builder.compile(checkpointer=memory)
app.invoke(
{"messages": [HumanMessage(content="what is 2.2 + 2.22?")]},
config={"configurable": {"thread_id": 1111}}
)
a
Building a Simple LangGraph Application (Parallel Branch)
import operator
from typing import Annotated, Sequence , Any
from typing_extensions import TypedDict
from langchain.llms.fake import FakeListLLM
from langgraph.graph import StateGraph
import operator
from typing import Annotated, Sequence
from typing import Annotated, Any
from typing_extensions import TypedDict
from langgraph.graph import END, START, StateGraph
import operator
from langgraph.graph import StateGraph
def reduce_fanouts(left, right):
if left is None:
left = []
if not right:
# Overwrite
return []
return left + right
class State(TypedDict):
# The operator.add reducer fn makes this append-only
aggregate: Annotated[list, operator.add]
fanout_values: Annotated[list, reduce_fanouts]
which: str
class ReturnNodeValue:
def __init__(self, node_secret: str):
self._value = node_secret
def __call__(self, state: State) -> Any:
print(f"Adding {self._value} to {state['aggregate']}")
return {"aggregate": [self._value]}
def aggregate_fanout_values(state: State) -> Any:
# Sort by reliability
ranked_values = sorted(
state["fanout_values"], key=lambda x: x["reliability"], reverse=True
)
return {
"aggregate": [x["value"] for x in ranked_values] + ["I'm E"],
"fanout_values": [],
}
def route_bc_or_cd(state: State) -> Sequence[str]:
return state["which"].split(",")
class ParallelReturnNodeValue:
def __init__(
self,
llm: FakeListLLM,
reliability: float,
):
self._llm = llm
self._reliability = reliability
def __call__(self, state: State) -> Any:
self._value = self._llm.invoke("??")
print(self._value)
print(f"Adding {self._value} to {state['aggregate']} in parallel.")
return {
"fanout_values": [
{
"value": [self._value],
"reliability": self._reliability,
}
]
}
builder = StateGraph(State)
builder.add_node("a", ReturnNodeValue("I'm A"))
builder.set_entry_point("a")
res = ["Action: python_repl_ast\nAction Input: print(2.2 + 2.22)", "Final Answer: 4.42"]
llm1 = FakeListLLM(responses=res)
res = ['hi', 'hello', 'howdy', 'hey']
llm2 = FakeListLLM(responses=res)
res = ['bye', 'goodbye', 'see ya', 'later', 'peace']
llm3 = FakeListLLM(responses=res)
builder.add_node("b", ParallelReturnNodeValue(llm1, reliability=0.9))
builder.add_node("c", ParallelReturnNodeValue(llm2, reliability=0.1))
builder.add_node("d", ParallelReturnNodeValue(llm3, reliability=0.3))
builder.add_node("e", aggregate_fanout_values)
intermediates = ["b", "c", "d"]
builder.add_conditional_edges("a", route_bc_or_cd, intermediates)
for node in intermediates:
builder.add_edge(node, "e")
builder.set_finish_point("e")
graph = builder.compile()
from IPython.display import Image, display
display(Image(graph.get_graph().draw_mermaid_png()))
result = graph.invoke({"aggregate": [], "which": "b,d", "fanout_values": []})
result
결론
langgraph에 대해서 알아보았고, 아직 먼가 익숙하지 않은 부분이 있지만 익숙해지면 workflow로 만들 때 매우 좋을 것 같고 특히 cycle이 있는 것을 쉽게 구현할 수 있어 좋아보입니다.
참고
'관심있는 주제 > LLM' 카테고리의 다른 글
Layout LM(=Language Model) 알아보기 - TODO (1) | 2024.07.23 |
---|---|
논문 정리) Searching for Best Practices in Retrieval-Augmented Generation (0) | 2024.07.05 |
Advanced RAG - 질문 유형 및 다양한 질문 유형을 위한 방법론(Ranker) (1) | 2024.06.22 |
LLM) HuggingFace 에 사용하는 Tokenizer 의 결과 비교하는 Streamlit APP (0) | 2024.06.01 |
LLM) Quantization 방법론 알아보기 (GPTQ | QAT | AWQ | GGUF | GGML | PTQ) (0) | 2024.04.29 |