Goose 实现细节
本章概览
本章将深入分析 Goose 的关键实现细节,包括:
- Agent 核心实现
- Provider 系统实现
- Extension Manager 实现
- 工具检查和安全机制
- 上下文压缩策略
1. Agent 核心实现
1.1 Agent 创建
rust
// 位于 crates/goose/src/agents/agent.rs
impl Agent {
pub fn new() -> Self {
// 创建通信通道
let (confirm_tx, confirm_rx) = mpsc::channel(32);
let (tool_tx, tool_rx) = mpsc::channel(32);
// 创建共享 Provider 引用
let provider = Arc::new(Mutex::new(None));
Self {
provider: provider.clone(),
extension_manager: Arc::new(ExtensionManager::new(provider.clone())),
sub_recipes: Mutex::new(HashMap::new()),
final_output_tool: Arc::new(Mutex::new(None)),
frontend_tools: Mutex::new(HashMap::new()),
frontend_instructions: Mutex::new(None),
prompt_manager: Mutex::new(PromptManager::new()),
confirmation_tx: confirm_tx,
confirmation_rx: Mutex::new(confirm_rx),
tool_result_tx: tool_tx,
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
scheduler_service: Mutex::new(None),
retry_manager: RetryManager::new(),
tool_inspection_manager: Self::create_default_tool_inspection_manager(),
}
}
}
1.2 工具检查管理器
Agent 使用多层检查器确保安全:
rust
fn create_default_tool_inspection_manager() -> ToolInspectionManager {
let mut tool_inspection_manager = ToolInspectionManager::new();
// 1. 安全检查器(最高优先级)
tool_inspection_manager.add_inspector(
Box::new(SecurityInspector::new())
);
// 2. 权限检查器
tool_inspection_manager.add_inspector(
Box::new(PermissionInspector::new(
GooseMode::SmartApprove,
HashSet::new(), // readonly tools
HashSet::new(), // regular tools
))
);
// 3. 重复检查器(防止无限循环)
tool_inspection_manager.add_inspector(
Box::new(RepetitionInspector::new(None))
);
tool_inspection_manager
}
1.3 Reply 方法(交互循环核心)
rust
// 简化版实现
impl Agent {
pub async fn reply(
&self,
conversation: &mut Conversation,
cancel_token: CancellationToken,
) -> Result<BoxStream<AgentEvent>> {
let tools = self.extension_manager.all_tools().await;
let system_prompt = self.build_system_prompt().await;
// 创建回复上下文
let context = ReplyContext {
conversation: conversation.clone(),
tools,
system_prompt,
goose_mode: self.config.goose_mode,
};
// 开始交互循环
self.reply_loop(context, cancel_token).await
}
async fn reply_loop(
&self,
mut context: ReplyContext,
cancel_token: CancellationToken,
) -> Result<BoxStream<AgentEvent>> {
loop {
// 检查取消令牌
if cancel_token.is_cancelled() {
break;
}
// 1. 调用 LLM
let response = self.provider_chat(&context).await?;
// 2. 发送消息事件
yield AgentEvent::Message(response.clone());
// 3. 检查是否有工具调用
if response.tool_calls.is_empty() {
break; // 没有工具调用,结束循环
}
// 4. 执行工具调用
for tool_call in &response.tool_calls {
// 4.1 工具检查
let check_result = self.tool_inspection_manager
.check_tool(tool_call)
.await?;
match check_result {
PermissionCheckResult::Approved => {
// 4.2 执行工具
let result = self.execute_tool(tool_call).await;
// 4.3 添加结果到对话
context.conversation.add_tool_result(
&tool_call.id,
result,
);
}
PermissionCheckResult::Denied(reason) => {
context.conversation.add_tool_result(
&tool_call.id,
ToolResult::error(reason),
);
}
PermissionCheckResult::NeedsConfirmation => {
// 等待用户确认
let confirmed = self.wait_for_confirmation(&tool_call).await?;
// ...
}
}
}
}
Ok(stream)
}
}
2. Provider 实现
2.1 OpenAI Provider
rust
// 位于 crates/goose/src/providers/openai.rs
pub struct OpenAIProvider {
client: reqwest::Client,
api_key: String,
base_url: String,
model: String,
model_info: ModelInfo,
}
impl OpenAIProvider {
pub fn new(model: &str, config: &Config) -> Result<Self> {
let api_key = config.get_secret("OPENAI_API_KEY")?;
let base_url = config.get("OPENAI_API_BASE")
.unwrap_or("https://api.openai.com/v1".to_string());
// 获取模型信息
let model_info = Self::get_model_info(model)?;
Ok(Self {
client: reqwest::Client::new(),
api_key,
base_url,
model: model.to_string(),
model_info,
})
}
}
#[async_trait]
impl Provider for OpenAIProvider {
fn name(&self) -> &str {
"openai"
}
fn model_info(&self) -> &ModelInfo {
&self.model_info
}
async fn chat(
&self,
messages: &[Message],
tools: &[Tool],
) -> Result<Message, ProviderError> {
// 1. 转换消息格式
let openai_messages = self.convert_messages(messages);
// 2. 转换工具格式
let openai_tools = self.convert_tools(tools);
// 3. 构建请求
let request = ChatCompletionRequest {
model: &self.model,
messages: openai_messages,
tools: Some(openai_tools),
temperature: Some(0.7),
max_tokens: Some(4096),
};
// 4. 发送请求
let response = self.client
.post(format!("{}/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await
.map_err(|e| ProviderError::NetworkError(e.to_string()))?;
// 5. 处理响应
let result: ChatCompletionResponse = response
.json()
.await
.map_err(|e| ProviderError::Unknown(e.to_string()))?;
// 6. 转换回 Goose Message 格式
self.convert_response(result)
}
}
2.2 Ollama Provider(本地模型)
rust
// 位于 crates/goose/src/providers/ollama.rs
pub struct OllamaProvider {
client: reqwest::Client,
base_url: String,
model: String,
}
impl OllamaProvider {
pub fn new(model: &str, config: &Config) -> Result<Self> {
let base_url = config.get("OLLAMA_HOST")
.unwrap_or("http://localhost:11434".to_string());
Ok(Self {
client: reqwest::Client::new(),
base_url,
model: model.to_string(),
})
}
}
#[async_trait]
impl Provider for OllamaProvider {
async fn chat(
&self,
messages: &[Message],
tools: &[Tool],
) -> Result<Message, ProviderError> {
let request = OllamaChatRequest {
model: &self.model,
messages: self.convert_messages(messages),
tools: Some(self.convert_tools(tools)),
stream: false,
};
let response = self.client
.post(format!("{}/api/chat", self.base_url))
.json(&request)
.send()
.await?;
// 处理响应...
}
}
2.3 Provider 工厂
rust
// 位于 crates/goose/src/providers/factory.rs
pub fn create(
provider_name: &str,
model_name: &str,
config: &Config,
) -> Result<Box<dyn Provider>> {
match provider_name.to_lowercase().as_str() {
"openai" => Ok(Box::new(OpenAIProvider::new(model_name, config)?)),
"anthropic" => Ok(Box::new(AnthropicProvider::new(model_name, config)?)),
"ollama" => Ok(Box::new(OllamaProvider::new(model_name, config)?)),
"azure" => Ok(Box::new(AzureProvider::new(model_name, config)?)),
"bedrock" => Ok(Box::new(BedrockProvider::new(model_name, config)?)),
"gcpvertexai" => Ok(Box::new(GCPVertexAIProvider::new(model_name, config)?)),
"openrouter" => Ok(Box::new(OpenRouterProvider::new(model_name, config)?)),
"litellm" => Ok(Box::new(LiteLLMProvider::new(model_name, config)?)),
_ => Err(anyhow!("Unknown provider: {}", provider_name)),
}
}
3. Extension Manager 实现
3.1 ExtensionManager 结构
rust
// 位于 crates/goose/src/agents/extension_manager.rs
pub struct ExtensionManager {
// 内置扩展
builtin_extensions: HashMap<String, Arc<dyn Extension>>,
// MCP 客户端
mcp_clients: RwLock<HashMap<String, McpClient>>,
// 扩展配置
extension_configs: RwLock<Vec<ExtensionConfig>>,
// Provider 引用(用于子 Agent)
provider: SharedProvider,
}
3.2 注册扩展
rust
impl ExtensionManager {
pub async fn register(&self, config: ExtensionConfig) -> Result<()> {
match &config.extension_type {
ExtensionType::Builtin(name) => {
// 内置扩展已预先注册
self.enable_builtin(name)?;
}
ExtensionType::Stdio(command, args) => {
// 启动 MCP Server 进程
let client = McpClient::spawn(
command,
args,
config.env.clone(),
).await?;
// 初始化连接
client.initialize().await?;
// 保存客户端
self.mcp_clients.write().await
.insert(config.name.clone(), client);
}
ExtensionType::Sse(url) => {
// 连接到 SSE 端点
let client = McpClient::connect_sse(url).await?;
client.initialize().await?;
self.mcp_clients.write().await
.insert(config.name.clone(), client);
}
}
// 保存配置
self.extension_configs.write().await.push(config);
Ok(())
}
}
3.3 获取所有工具
rust
impl ExtensionManager {
pub async fn all_tools(&self) -> Vec<Tool> {
let mut tools = Vec::new();
// 收集内置扩展的工具
for extension in self.builtin_extensions.values() {
tools.extend(extension.tools().iter().cloned());
}
// 收集 MCP 扩展的工具
let clients = self.mcp_clients.read().await;
for client in clients.values() {
if let Ok(mcp_tools) = client.list_tools().await {
tools.extend(mcp_tools);
}
}
tools
}
}
3.4 调用工具
rust
impl ExtensionManager {
    /// Dispatches a tool call to the extension that owns `tool_name`.
    ///
    /// Built-in extensions are checked first, then connected MCP clients.
    ///
    /// # Errors
    /// - Propagates the error returned by `find_extension_for_tool`.
    /// - `INVALID_PARAMS` when a built-in tool is called with `params` that is
    ///   not a JSON object.
    /// - `METHOD_NOT_FOUND` when the owning extension is neither a built-in
    ///   extension nor a connected MCP client.
    pub async fn call_tool(
        &self,
        tool_name: &str,
        params: Value,
    ) -> ToolResult<CallToolResult> {
        // 1. Find the extension that owns this tool
        let extension_name = self.find_extension_for_tool(tool_name).await?;
        // 2. Dispatch to that extension
        if let Some(extension) = self.builtin_extensions.get(&extension_name) {
            // Built-in extension. The arguments are produced by the model, so a
            // non-object value must be reported as an error rather than
            // unwrapped (which would panic the agent).
            let arguments = match params {
                Value::Object(map) => map,
                _ => {
                    return Err(ErrorData {
                        code: ErrorCode::INVALID_PARAMS,
                        message: format!(
                            "Arguments for tool '{}' must be a JSON object",
                            tool_name
                        )
                        .into(),
                        data: None,
                    });
                }
            };
            extension.call_tool(tool_name, arguments).await
        } else if let Some(client) = self.mcp_clients.read().await.get(&extension_name) {
            // MCP extension
            client.call_tool(tool_name, params).await
        } else {
            Err(ErrorData {
                code: ErrorCode::METHOD_NOT_FOUND,
                message: format!("Tool not found: {}", tool_name).into(),
                data: None,
            })
        }
    }
}
4. MCP 客户端实现
4.1 McpClient 结构
rust
// 位于 crates/goose/src/agents/mcp_client.rs
pub struct McpClient {
// 传输层
transport: Arc<dyn McpTransport>,
// 服务器能力
capabilities: Option<ServerCapabilities>,
// 工具缓存
tools_cache: RwLock<Option<Vec<Tool>>>,
// 请求 ID 计数器
request_id: AtomicU64,
}
4.2 初始化连接
rust
impl McpClient {
pub async fn initialize(&self) -> Result<ServerCapabilities> {
// 发送初始化请求
let response = self.send_request(
"initialize",
json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "goose",
"version": env!("CARGO_PKG_VERSION")
}
}),
).await?;
// 解析服务器能力
let capabilities: ServerCapabilities = serde_json::from_value(response)?;
// 发送 initialized 通知
self.send_notification("notifications/initialized", json!({})).await?;
Ok(capabilities)
}
}
4.3 工具调用
rust
impl McpClient {
pub async fn call_tool(
&self,
name: &str,
arguments: Value,
) -> ToolResult<CallToolResult> {
let response = self.send_request(
"tools/call",
json!({
"name": name,
"arguments": arguments
}),
).await?;
// 解析结果
let result: CallToolResult = serde_json::from_value(response)?;
if result.is_error.unwrap_or(false) {
Err(ErrorData::from_result(&result))
} else {
Ok(result)
}
}
}
5. 安全检查实现
5.1 SecurityInspector
rust
// 位于 crates/goose/src/security/security_inspector.rs
/// Inspector that holds a list of regular expressions describing dangerous
/// shell commands.
pub struct SecurityInspector {
    // Patterns for dangerous commands
    dangerous_patterns: Vec<Regex>,
}
impl SecurityInspector {
    /// Builds the inspector with its built-in list of dangerous-command patterns.
    ///
    /// # Panics
    /// Panics if one of the hard-coded patterns is not a valid regular
    /// expression; that would be a programming error in this list.
    pub fn new() -> Self {
        let patterns = vec![
            Regex::new(r"rm\s+-rf\s+/").unwrap(), // recursive delete starting at the root
            Regex::new(r"mkfs\.").unwrap(), // formatting a filesystem
            Regex::new(r"dd\s+if=.*of=/dev/").unwrap(), // writing directly to a device
            // Fork bomb, e.g. `:(){ :|:& };:`. The parentheses and braces are
            // regex metacharacters and must be escaped: left unescaped, `{` is
            // parsed as a counted repetition, the pattern fails to compile and
            // `unwrap()` panics. Optional whitespace is allowed between tokens.
            Regex::new(r":\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;\s*:").unwrap(),
            // ... more patterns
        ];
        Self { dangerous_patterns: patterns }
    }
}
impl ToolInspector for SecurityInspector {
fn priority(&self) -> u32 {
100 // 最高优先级
}
async fn check(
&self,
tool_call: &ToolCall,
_context: &InspectionContext,
) -> Result<InspectionResult> {
// 检查 shell 命令
if tool_call.name == "shell" {
if let Some(command) = tool_call.arguments.get("command") {
let cmd = command.as_str().unwrap_or("");
for pattern in &self.dangerous_patterns {
if pattern.is_match(cmd) {
return Ok(InspectionResult::Block(
format!("Blocked dangerous command: {}", cmd)
));
}
}
}
}
Ok(InspectionResult::Allow)
}
}
5.2 PermissionInspector
rust
// 位于 crates/goose/src/permission/permission_inspector.rs
pub struct PermissionInspector {
mode: GooseMode,
readonly_tools: HashSet<String>,
regular_tools: HashSet<String>,
}
impl ToolInspector for PermissionInspector {
async fn check(
&self,
tool_call: &ToolCall,
context: &InspectionContext,
) -> Result<InspectionResult> {
match self.mode {
GooseMode::AutoApprove => {
// 自动批准所有工具
Ok(InspectionResult::Allow)
}
GooseMode::AskEveryTime => {
// 每次都需要确认
Ok(InspectionResult::NeedsConfirmation)
}
GooseMode::ChatOnly => {
// 禁止所有工具
Ok(InspectionResult::Block("Chat-only mode".to_string()))
}
GooseMode::SmartApprove => {
// 智能判断
if self.readonly_tools.contains(&tool_call.name) {
// 只读工具自动批准
Ok(InspectionResult::Allow)
} else if self.is_safe_operation(tool_call) {
// 安全操作自动批准
Ok(InspectionResult::Allow)
} else {
// 需要确认
Ok(InspectionResult::NeedsConfirmation)
}
}
}
}
}
5.3 RepetitionInspector
rust
// 位于 crates/goose/src/tool_monitor.rs
pub struct RepetitionInspector {
// 最近的工具调用记录
recent_calls: Mutex<VecDeque<ToolCallRecord>>,
// 最大记录数
max_records: usize,
// 重复阈值
repetition_threshold: usize,
}
impl ToolInspector for RepetitionInspector {
async fn check(
&self,
tool_call: &ToolCall,
_context: &InspectionContext,
) -> Result<InspectionResult> {
let mut recent = self.recent_calls.lock().await;
// 计算相同调用的次数
let count = recent.iter()
.filter(|r| r.name == tool_call.name && r.args == tool_call.arguments)
.count();
// 添加当前调用
recent.push_back(ToolCallRecord {
name: tool_call.name.clone(),
args: tool_call.arguments.clone(),
timestamp: Instant::now(),
});
// 保持记录数量限制
while recent.len() > self.max_records {
recent.pop_front();
}
// 检查重复
if count >= self.repetition_threshold {
Ok(InspectionResult::Block(format!(
"Tool {} has been called {} times with same arguments",
tool_call.name, count
)))
} else {
Ok(InspectionResult::Allow)
}
}
}
6. 上下文压缩
6.1 为什么需要压缩?
LLM 有上下文长度限制,长对话会超出限制。Goose 实现了自动压缩机制。
6.2 压缩策略
rust
// 位于 crates/goose/src/context_mgmt/mod.rs
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.7; // 70% 触发压缩
pub async fn check_if_compaction_needed(
conversation: &Conversation,
model_info: &ModelInfo,
) -> bool {
let current_tokens = count_tokens(conversation);
let threshold = (model_info.context_limit as f64 * DEFAULT_COMPACTION_THRESHOLD) as usize;
current_tokens > threshold
}
/// Replaces the older part of a conversation with a generated summary while
/// keeping the most recent messages verbatim.
///
/// The last `KEEP_RECENT_COUNT` messages are kept in their original order.
/// Everything before them is passed to `generate_summary` and re-inserted as a
/// single system message. A conversation with `KEEP_RECENT_COUNT` messages or
/// fewer has nothing to summarize and is returned unchanged.
///
/// # Errors
/// Returns the error from `generate_summary`; the conversation is not modified
/// in that case because it is only cleared after the summary succeeds.
pub async fn compact_messages(
    conversation: &mut Conversation,
    provider: &dyn Provider,
) -> Result<()> {
    let total = conversation.len();
    // Guard against `total - KEEP_RECENT_COUNT` underflowing on short
    // conversations: there is no history older than the retained tail.
    if total <= KEEP_RECENT_COUNT {
        return Ok(());
    }
    let split_at = total - KEEP_RECENT_COUNT;
    // 1. Older history that will be summarized
    let old_messages = conversation.messages()
        .iter()
        .take(split_at)
        .cloned()
        .collect::<Vec<_>>();
    // 2. Most recent messages, kept as-is and already in chronological order
    let recent_messages = conversation.messages()
        .iter()
        .skip(split_at)
        .cloned()
        .collect::<Vec<_>>();
    // 3. Ask the LLM for a summary of the older history
    let summary = generate_summary(provider, &old_messages).await?;
    // 4. Rebuild the conversation: summary first, then the retained tail
    conversation.clear();
    conversation.add_system_message(&format!(
        "Previous conversation summary:\n{}",
        summary
    ));
    for msg in recent_messages {
        conversation.add_message(msg);
    }
    Ok(())
}
7. 流式输出
7.1 Stream 处理
rust
// 位于 crates/goose/src/agents/agent.rs
/// Boxed, sendable stream of items produced while a tool call is running.
///
/// `ToolStreamItem` is generic, so the alias must name the result type; it is
/// the `ToolResult<CallToolResult>` that `tool_stream` yields as its final item.
pub type ToolStream =
    Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<CallToolResult>>> + Send>>;
/// One item of a tool stream: an intermediate notification or the final result.
pub enum ToolStreamItem<T> {
    Message(ServerNotification), // MCP notification
    Result(T), // final result
}
/// Merges a notification stream and a result future into a single `ToolStream`.
///
/// Each notification received from `rx` is yielded as `ToolStreamItem::Message`.
/// When `done` completes, its output is yielded as `ToolStreamItem::Result` and
/// the stream ends.
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
where
    S: Stream<Item = ServerNotification> + Send + Unpin + 'static,
    F: Future<Output = ToolResult<CallToolResult>> + Send + 'static,
{
    Box::pin(async_stream::stream! {
        // Pin the future so it can be polled by reference on every iteration
        tokio::pin!(done);
        let mut rx = rx;
        loop {
            tokio::select! {
                // Forward a notification
                Some(msg) = rx.next() => {
                    yield ToolStreamItem::Message(msg);
                }
                // Forward the final result and stop
                r = &mut done => {
                    yield ToolStreamItem::Result(r);
                    break;
                }
            }
        }
    })
}
7.2 Provider 流式响应
rust
// `#[async_trait]` is applied here to match the other `Provider` impls.
#[async_trait]
impl Provider for OpenAIProvider {
    /// Sends a chat request and returns the response as a stream of events
    /// parsed from the server-sent-event byte stream.
    ///
    /// # Errors
    /// Returns an error if sending the request fails. Errors that occur while
    /// reading the body are surfaced per item as `ProviderError::NetworkError`.
    async fn chat_stream(
        &self,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<BoxStream<StreamEvent>> {
        let request = /* ... */;
        // Same URL and bearer authentication as `chat`: the provider stores
        // `base_url` and `api_key`; it has no `endpoint` field.
        let response = self.client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?;
        // Process the SSE stream
        let stream = response
            .bytes_stream()
            .map_err(|e| ProviderError::NetworkError(e.to_string()))
            .and_then(|bytes| async move {
                // Parse one SSE event
                parse_sse_event(&bytes)
            });
        Ok(Box::pin(stream))
    }
}
8. 错误恢复
8.1 Provider 重试
rust
// 位于 crates/goose/src/providers/retry.rs
pub struct RetryConfig {
pub max_retries: u32,
pub initial_delay: Duration,
pub max_delay: Duration,
pub multiplier: f64,
}
/// Runs `operation`, retrying it on rate-limit and server errors.
///
/// - `RateLimited`: waits for the server-provided `retry_after` if present,
///   otherwise for the current backoff delay, then retries.
/// - `ServerError`: waits for the current backoff delay, then multiplies the
///   delay by `config.multiplier`, capped at `config.max_delay`.
///
/// Both kinds of retry count against `config.max_retries`, so a provider that
/// keeps answering with rate-limit errors cannot keep the caller waiting
/// forever.
///
/// # Errors
/// Returns the last error once `max_retries` retries have been used, and
/// returns any other `ProviderError` immediately without retrying.
pub async fn with_retry<F, Fut, T>(
    config: &RetryConfig,
    mut operation: F,
) -> Result<T, ProviderError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, ProviderError>>,
{
    let mut attempts = 0;
    let mut delay = config.initial_delay;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            Err(ProviderError::RateLimited { retry_after }) if attempts < config.max_retries => {
                // Rate limited: wait, then retry
                let wait_time = retry_after.unwrap_or(delay);
                tokio::time::sleep(wait_time).await;
            }
            Err(ProviderError::ServerError(_)) if attempts < config.max_retries => {
                // Server error: retry with exponential backoff
                tokio::time::sleep(delay).await;
                delay = Duration::from_secs_f64(
                    (delay.as_secs_f64() * config.multiplier).min(config.max_delay.as_secs_f64())
                );
            }
            // Non-retryable error, or the retry budget is exhausted
            Err(e) => return Err(e),
        }
        attempts += 1;
    }
}
8.2 工具错误处理
rust
// 工具错误会被发回 LLM
impl Agent {
async fn handle_tool_error(
&self,
tool_call: &ToolCall,
error: ErrorData,
conversation: &mut Conversation,
) {
// 构造错误消息
let error_message = format!(
"Tool '{}' failed: {}\n\nPlease try a different approach.",
tool_call.name,
error.message
);
// 添加到对话
conversation.add_tool_result(
&tool_call.id,
CallToolResult {
content: vec![Content::Text { text: error_message }],
is_error: Some(true),
},
);
// LLM 会看到这个错误并尝试恢复
}
}
9. 性能优化
9.1 工具缓存
rust
// 工具列表缓存
pub struct ExtensionManager {
tools_cache: RwLock<Option<Vec<Tool>>>,
cache_invalidated: AtomicBool,
}
impl ExtensionManager {
pub async fn all_tools(&self) -> Vec<Tool> {
// 检查缓存是否有效
if !self.cache_invalidated.load(Ordering::SeqCst) {
if let Some(cached) = self.tools_cache.read().await.as_ref() {
return cached.clone();
}
}
// 重新获取工具列表
let tools = self.fetch_all_tools().await;
// 更新缓存
*self.tools_cache.write().await = Some(tools.clone());
self.cache_invalidated.store(false, Ordering::SeqCst);
tools
}
}
9.2 并行工具执行
rust
impl Agent {
async fn execute_tools_parallel(
&self,
tool_calls: &[ToolCall],
) -> Vec<(String, ToolResult<CallToolResult>)> {
let futures = tool_calls.iter().map(|call| {
let call_id = call.id.clone();
let name = call.name.clone();
let args = call.arguments.clone();
async move {
let result = self.extension_manager
.call_tool(&name, args)
.await;
(call_id, result)
}
});
// 并行执行所有工具调用
futures::future::join_all(futures).await
}
}
10. 关键代码路径
10.1 完整请求流程
用户输入
│
▼
┌───────────────────────────────────────────────────────────┐
│ CLI/Desktop: goose session │
│ └── 创建 Agent │
│ └── 初始化 ExtensionManager │
│ └── 连接 MCP Servers │
└───────────────────────────────────────────────────────────┘
│
▼
┌───────────────────────────────────────────────────────────┐
│ Agent::reply() │
│ ├── 构建系统提示 │
│ ├── 收集所有工具 │
│ └── 开始交互循环 │
└───────────────────────────────────────────────────────────┘
│
▼
┌───────────────────────────────────────────────────────────┐
│ 交互循环 │
│ ├── Provider::chat() → 调用 LLM │
│ ├── 解析响应 │
│ │ ├── 文本 → 返回给用户 │
│ │ └── 工具调用 → 继续处理 │
│ ├── ToolInspectionManager::check() → 安全检查 │
│ ├── ExtensionManager::call_tool() → 执行工具 │
│ │ └── McpClient::call_tool() → MCP 请求 │
│ └── 工具结果添加到对话 → 继续循环 │
└───────────────────────────────────────────────────────────┘
│
▼
输出结果