8205a22972
Ollama adapter (internal/chat/ollama.go): - Implements model.LLM interface for ADK Go - Talks to Ollama's OpenAI-compatible API (/v1/chat/completions) - Full tool/function calling support (tested with Mistral Small 3.2) - Converts ADK types to OpenAI format (messages, tools, tool_calls) - Configurable via OLLAMA_HOST and OLLAMA_MODEL env vars Multi-provider handler: - MODEL_PROVIDER env: "gemini" (default) or "ollama" - Gemini: requires GOOGLE_API_KEY (pay-as-you-go recommended) - Ollama: connects to local or Tailscale-remote instance Rate limiter: - 30 requests/hour per IP on /api/chat endpoint - Uses existing middleware.NewRateLimiter pattern Tested: Ollama + Mistral Small 3.2 on M4 Pro 64GB — correct answers
431 lines
10 KiB
Go
431 lines
10 KiB
Go
// Package chat provides an ADK Go agent that answers questions about CV data.
|
|
package chat
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"iter"
	"net/http"
	"strings"
	"time"

	"google.golang.org/adk/model"
	"google.golang.org/genai"
)
|
|
|
|
// OllamaModel implements model.LLM using Ollama's OpenAI-compatible API.
type OllamaModel struct {
	host      string // base URL without trailing slash, e.g. "http://localhost:11434"
	modelName string // e.g. "mistral-small3.2"
	client    *http.Client
}

// NewOllamaModel creates a new Ollama-backed LLM.
//
// host is the server's base URL (for example "http://localhost:11434").
// Trailing slashes are stripped so request paths can be appended directly
// without producing "//v1/...". modelName is the Ollama model tag sent with
// every request and returned by Name.
//
// The HTTP client carries an overall per-request timeout as a backstop for
// hung connections. The context passed to GenerateContent can still cancel
// a request sooner. The limit is deliberately generous because a local model
// may have to load into memory and generate a long answer before replying.
func NewOllamaModel(host, modelName string) *OllamaModel {
	return &OllamaModel{
		host:      strings.TrimRight(host, "/"),
		modelName: modelName,
		client:    &http.Client{Timeout: 5 * time.Minute},
	}
}

// Name returns the Ollama model name this adapter sends requests for.
func (m *OllamaModel) Name() string {
	return m.modelName
}
|
|
|
|
// Verify OllamaModel implements model.LLM at compile time.
|
|
var _ model.LLM = (*OllamaModel)(nil)
|
|
|
|
// GenerateContent sends a request to Ollama and returns ADK-compatible responses.
|
|
func (m *OllamaModel) GenerateContent(ctx context.Context, req *model.LLMRequest, stream bool) iter.Seq2[*model.LLMResponse, error] {
|
|
return func(yield func(*model.LLMResponse, error) bool) {
|
|
resp, err := m.generate(ctx, req)
|
|
yield(resp, err)
|
|
}
|
|
}
|
|
|
|
// --- OpenAI-compatible request/response types ---
//
// These declarations model only the subset of the /v1/chat/completions wire
// format that this adapter sends to Ollama and reads back. Field order
// determines the key order encoding/json emits, so it is kept stable.
type (
	// oaiMessage is one element of the request "messages" array. The same
	// shape is decoded from a response choice's "message" object.
	oaiMessage struct {
		Role       string        `json:"role"`
		Content    string        `json:"content,omitempty"`
		ToolCalls  []oaiToolCall `json:"tool_calls,omitempty"`
		ToolCallID string        `json:"tool_call_id,omitempty"`
	}

	// oaiToolCall is a single function invocation requested by the assistant.
	oaiToolCall struct {
		ID       string          `json:"id"`
		Type     string          `json:"type"`
		Function oaiToolFunction `json:"function"`
	}

	// oaiToolFunction names the function to call. Arguments holds the call's
	// arguments encoded as a JSON string, not as a nested JSON object.
	oaiToolFunction struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	}

	// oaiTool advertises one callable function to the model.
	oaiTool struct {
		Type     string          `json:"type"`
		Function oaiToolFuncDecl `json:"function"`
	}

	// oaiToolFuncDecl describes a function: its name, an optional description
	// and an optional parameters schema (any JSON-encodable value).
	oaiToolFuncDecl struct {
		Name        string `json:"name"`
		Description string `json:"description,omitempty"`
		Parameters  any    `json:"parameters,omitempty"`
	}

	// oaiRequest is the request body. Stream is always serialized, even when
	// false. Tools and Temperature are omitted when empty or nil.
	oaiRequest struct {
		Model       string       `json:"model"`
		Messages    []oaiMessage `json:"messages"`
		Tools       []oaiTool    `json:"tools,omitempty"`
		Stream      bool         `json:"stream"`
		Temperature *float32     `json:"temperature,omitempty"`
	}

	// oaiResponse is the decoded non-streaming response body.
	oaiResponse struct {
		Choices []oaiChoice `json:"choices"`
		Usage   *oaiUsage   `json:"usage,omitempty"`
		Model   string      `json:"model,omitempty"`
	}

	// oaiChoice is one completion alternative together with its finish_reason.
	oaiChoice struct {
		Message      oaiMessage `json:"message"`
		FinishReason string     `json:"finish_reason"`
	}

	// oaiUsage reports token accounting for the request.
	oaiUsage struct {
		PromptTokens     int32 `json:"prompt_tokens"`
		CompletionTokens int32 `json:"completion_tokens"`
		TotalTokens      int32 `json:"total_tokens"`
	}
)
|
|
|
|
// generate performs a synchronous (non-streaming) call to Ollama.
|
|
func (m *OllamaModel) generate(ctx context.Context, req *model.LLMRequest) (*model.LLMResponse, error) {
|
|
oaiReq := m.buildRequest(req)
|
|
|
|
body, err := json.Marshal(oaiReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: marshal request: %w", err)
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/v1/chat/completions", m.host)
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: create request: %w", err)
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
|
|
httpResp, err := m.client.Do(httpReq)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: send request: %w", err)
|
|
}
|
|
defer func() { _ = httpResp.Body.Close() }()
|
|
|
|
respBody, err := io.ReadAll(httpResp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ollama: read response: %w", err)
|
|
}
|
|
|
|
if httpResp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("ollama: HTTP %d: %s", httpResp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var oaiResp oaiResponse
|
|
if err := json.Unmarshal(respBody, &oaiResp); err != nil {
|
|
return nil, fmt.Errorf("ollama: unmarshal response: %w", err)
|
|
}
|
|
|
|
return m.convertResponse(&oaiResp)
|
|
}
|
|
|
|
// buildRequest converts an ADK LLMRequest into an OpenAI-compatible request.
|
|
func (m *OllamaModel) buildRequest(req *model.LLMRequest) *oaiRequest {
|
|
oaiReq := &oaiRequest{
|
|
Model: m.modelName,
|
|
Stream: false,
|
|
}
|
|
|
|
// Convert system instruction
|
|
if req.Config != nil && req.Config.SystemInstruction != nil {
|
|
text := extractText(req.Config.SystemInstruction)
|
|
if text != "" {
|
|
oaiReq.Messages = append(oaiReq.Messages, oaiMessage{
|
|
Role: "system",
|
|
Content: text,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Set temperature if provided
|
|
if req.Config != nil && req.Config.Temperature != nil {
|
|
oaiReq.Temperature = req.Config.Temperature
|
|
}
|
|
|
|
// Convert conversation messages
|
|
for _, content := range req.Contents {
|
|
msgs := convertContent(content)
|
|
oaiReq.Messages = append(oaiReq.Messages, msgs...)
|
|
}
|
|
|
|
// Convert tools (function declarations)
|
|
if req.Config != nil && req.Config.Tools != nil {
|
|
for _, t := range req.Config.Tools {
|
|
if t.FunctionDeclarations != nil {
|
|
for _, fd := range t.FunctionDeclarations {
|
|
oaiReq.Tools = append(oaiReq.Tools, convertFunctionDecl(fd))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return oaiReq
|
|
}
|
|
|
|
// convertContent converts a genai.Content into one or more OpenAI messages.
|
|
func convertContent(content *genai.Content) []oaiMessage {
|
|
if content == nil {
|
|
return nil
|
|
}
|
|
|
|
role := mapRole(content.Role)
|
|
|
|
// Check if this content has function calls (assistant with tool_calls)
|
|
var toolCalls []oaiToolCall
|
|
var textParts []string
|
|
var funcResponses []oaiMessage
|
|
|
|
for _, part := range content.Parts {
|
|
if part.Text != "" {
|
|
textParts = append(textParts, part.Text)
|
|
}
|
|
if part.FunctionCall != nil {
|
|
argsJSON, _ := json.Marshal(part.FunctionCall.Args)
|
|
toolCalls = append(toolCalls, oaiToolCall{
|
|
ID: part.FunctionCall.ID,
|
|
Type: "function",
|
|
Function: oaiToolFunction{
|
|
Name: part.FunctionCall.Name,
|
|
Arguments: string(argsJSON),
|
|
},
|
|
})
|
|
}
|
|
if part.FunctionResponse != nil {
|
|
respJSON, _ := json.Marshal(part.FunctionResponse.Response)
|
|
funcResponses = append(funcResponses, oaiMessage{
|
|
Role: "tool",
|
|
Content: string(respJSON),
|
|
ToolCallID: part.FunctionResponse.ID,
|
|
})
|
|
}
|
|
}
|
|
|
|
var msgs []oaiMessage
|
|
|
|
// Build the primary message
|
|
if len(toolCalls) > 0 {
|
|
// Assistant message with tool calls
|
|
msg := oaiMessage{
|
|
Role: "assistant",
|
|
ToolCalls: toolCalls,
|
|
}
|
|
if len(textParts) > 0 {
|
|
combined := ""
|
|
for _, t := range textParts {
|
|
combined += t
|
|
}
|
|
msg.Content = combined
|
|
}
|
|
msgs = append(msgs, msg)
|
|
} else if len(textParts) > 0 {
|
|
combined := ""
|
|
for _, t := range textParts {
|
|
combined += t
|
|
}
|
|
msgs = append(msgs, oaiMessage{
|
|
Role: role,
|
|
Content: combined,
|
|
})
|
|
}
|
|
|
|
// Append function response messages separately
|
|
msgs = append(msgs, funcResponses...)
|
|
|
|
return msgs
|
|
}
|
|
|
|
// convertFunctionDecl converts a genai FunctionDeclaration to an OpenAI tool.
|
|
func convertFunctionDecl(fd *genai.FunctionDeclaration) oaiTool {
|
|
var params any
|
|
if fd.Parameters != nil {
|
|
params = convertSchema(fd.Parameters)
|
|
} else if fd.ParametersJsonSchema != nil {
|
|
params = fd.ParametersJsonSchema
|
|
}
|
|
|
|
return oaiTool{
|
|
Type: "function",
|
|
Function: oaiToolFuncDecl{
|
|
Name: fd.Name,
|
|
Description: fd.Description,
|
|
Parameters: params,
|
|
},
|
|
}
|
|
}
|
|
|
|
// convertSchema converts a genai.Schema to a JSON-Schema-compatible map.
|
|
func convertSchema(s *genai.Schema) map[string]any {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
|
|
m := make(map[string]any)
|
|
|
|
if s.Type != "" {
|
|
m["type"] = schemaTypeToJSON(s.Type)
|
|
}
|
|
if s.Description != "" {
|
|
m["description"] = s.Description
|
|
}
|
|
if len(s.Enum) > 0 {
|
|
m["enum"] = s.Enum
|
|
}
|
|
if s.Items != nil {
|
|
m["items"] = convertSchema(s.Items)
|
|
}
|
|
if len(s.Properties) > 0 {
|
|
props := make(map[string]any)
|
|
for k, v := range s.Properties {
|
|
props[k] = convertSchema(v)
|
|
}
|
|
m["properties"] = props
|
|
}
|
|
if len(s.Required) > 0 {
|
|
m["required"] = s.Required
|
|
}
|
|
|
|
return m
|
|
}
|
|
|
|
// schemaTypeToJSON maps genai.Type to JSON Schema type strings.
|
|
func schemaTypeToJSON(t genai.Type) string {
|
|
switch t {
|
|
case genai.TypeString:
|
|
return "string"
|
|
case genai.TypeNumber:
|
|
return "number"
|
|
case genai.TypeInteger:
|
|
return "integer"
|
|
case genai.TypeBoolean:
|
|
return "boolean"
|
|
case genai.TypeArray:
|
|
return "array"
|
|
case genai.TypeObject:
|
|
return "object"
|
|
default:
|
|
return "string"
|
|
}
|
|
}
|
|
|
|
// convertResponse converts an OpenAI response back to an ADK LLMResponse.
|
|
func (m *OllamaModel) convertResponse(resp *oaiResponse) (*model.LLMResponse, error) {
|
|
if len(resp.Choices) == 0 {
|
|
return nil, fmt.Errorf("ollama: empty response (no choices)")
|
|
}
|
|
|
|
choice := resp.Choices[0]
|
|
var parts []*genai.Part
|
|
|
|
// Handle tool calls
|
|
if len(choice.Message.ToolCalls) > 0 {
|
|
for _, tc := range choice.Message.ToolCalls {
|
|
var args map[string]any
|
|
if tc.Function.Arguments != "" {
|
|
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
|
// If args aren't valid JSON, wrap them
|
|
args = map[string]any{"raw": tc.Function.Arguments}
|
|
}
|
|
}
|
|
parts = append(parts, &genai.Part{
|
|
FunctionCall: &genai.FunctionCall{
|
|
ID: tc.ID,
|
|
Name: tc.Function.Name,
|
|
Args: args,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
// Handle text content
|
|
if choice.Message.Content != "" {
|
|
parts = append(parts, &genai.Part{
|
|
Text: choice.Message.Content,
|
|
})
|
|
}
|
|
|
|
content := &genai.Content{
|
|
Parts: parts,
|
|
Role: genai.RoleModel,
|
|
}
|
|
|
|
llmResp := &model.LLMResponse{
|
|
Content: content,
|
|
FinishReason: mapFinishReason(choice.FinishReason),
|
|
TurnComplete: true,
|
|
ModelVersion: resp.Model,
|
|
}
|
|
|
|
// Map usage metadata
|
|
if resp.Usage != nil {
|
|
llmResp.UsageMetadata = &genai.GenerateContentResponseUsageMetadata{
|
|
PromptTokenCount: resp.Usage.PromptTokens,
|
|
CandidatesTokenCount: resp.Usage.CompletionTokens,
|
|
TotalTokenCount: resp.Usage.TotalTokens,
|
|
}
|
|
}
|
|
|
|
return llmResp, nil
|
|
}
|
|
|
|
// mapRole translates a genai content role into an OpenAI chat role.
// Only "model" becomes "assistant"; "user" and every other value map to "user".
func mapRole(role string) string {
	if role == "model" {
		return "assistant"
	}
	return "user"
}
|
|
|
|
// mapFinishReason converts OpenAI finish reasons to genai finish reasons.
|
|
func mapFinishReason(reason string) genai.FinishReason {
|
|
switch reason {
|
|
case "stop":
|
|
return genai.FinishReasonStop
|
|
case "length":
|
|
return genai.FinishReasonMaxTokens
|
|
case "tool_calls":
|
|
return genai.FinishReasonStop // Tool calls are a normal stop
|
|
default:
|
|
return genai.FinishReasonStop
|
|
}
|
|
}
|
|
|
|
// extractText extracts all text from a genai.Content.
|
|
func extractText(content *genai.Content) string {
|
|
if content == nil {
|
|
return ""
|
|
}
|
|
var result string
|
|
for _, part := range content.Parts {
|
|
if part.Text != "" {
|
|
result += part.Text
|
|
}
|
|
}
|
|
return result
|
|
}
|