// RAG Pipeline
//
// Complete working RAG pipeline: reads markdown files from a docs/
// directory, chunks them, embeds them with Ollama, and answers questions
// from the content. Create a docs/ folder with a few .md files before
// running.
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
)
// Message is one turn of an Ollama chat conversation; the JSON tags
// match the /api/chat request schema.
type Message struct {
Role string `json:"role"` // e.g. "system", "user", "assistant"
Content string `json:"content"` // the text of this turn
}
// Chunk is one indexed piece of a source document: its text, the file
// it came from, and its embedding vector used for similarity search.
type Chunk struct {
Text string
Source string
Embedding []float64
}
// --- Chunking ---
// chunk splits text into paragraph-aligned pieces of at most size bytes.
// Paragraphs (separated by blank lines) are trimmed, greedily packed
// together, and rejoined with "\n\n". A single paragraph longer than
// size still becomes its own oversized chunk — paragraphs are never
// split internally. Whitespace-only input yields a nil slice.
func chunk(text string, size int) []string {
	const sep = "\n\n"
	paragraphs := strings.Split(text, sep)
	var chunks []string
	var current strings.Builder
	for _, p := range paragraphs {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		// Include the separator that would precede p in the size check;
		// the original omitted it, so chunks could exceed size by 2 bytes.
		if current.Len() > 0 && current.Len()+len(sep)+len(p) > size {
			chunks = append(chunks, current.String())
			current.Reset()
		}
		if current.Len() > 0 {
			current.WriteString(sep)
		}
		current.WriteString(p)
	}
	if current.Len() > 0 {
		chunks = append(chunks, current.String())
	}
	return chunks
}
// --- Embedding ---
// embed requests an embedding vector for text from the local Ollama
// server using the nomic-embed-text model. It returns an error on
// request-encoding failure, transport failure, a non-200 response,
// malformed response JSON, or an empty embeddings array (the original
// indexed Embeddings[0] unconditionally and would panic in that case).
func embed(text string) ([]float64, error) {
	body, err := json.Marshal(map[string]any{
		"model": "nomic-embed-text",
		"input": text,
	})
	if err != nil {
		return nil, fmt.Errorf("encoding embed request: %w", err)
	}
	resp, err := http.Post("http://localhost:11434/api/embed", "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading embed response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("embed API status %d: %s", resp.StatusCode, data)
	}
	var result struct {
		Embeddings [][]float64 `json:"embeddings"`
	}
	if err := json.Unmarshal(data, &result); err != nil {
		return nil, fmt.Errorf("decoding embed response: %w", err)
	}
	if len(result.Embeddings) == 0 {
		return nil, fmt.Errorf("embed API returned no embeddings")
	}
	return result.Embeddings[0], nil
}
func cosineSimilarity(a, b []float64) float64 {
var dot, normA, normB float64
for i := range a {
dot += a[i] * b[i]
normA += a[i] * a[i]
normB += b[i] * b[i]
}
if normA == 0 || normB == 0 {
return 0
}
return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}
// --- Indexing ---
// indexDocuments reads every .md file in dir, splits each into
// ~600-byte chunks, embeds each chunk via the local Ollama server, and
// returns the resulting corpus. Per-file progress is printed to stdout.
// Any read or embedding failure aborts indexing with a wrapped error
// (the original silently ignored os.ReadFile errors and would have
// indexed empty content).
func indexDocuments(dir string) ([]Chunk, error) {
	files, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	var chunks []Chunk
	for _, f := range files {
		// Skip subdirectories and anything that isn't markdown.
		if f.IsDir() || !strings.HasSuffix(f.Name(), ".md") {
			continue
		}
		content, err := os.ReadFile(filepath.Join(dir, f.Name()))
		if err != nil {
			return nil, fmt.Errorf("reading %s: %w", f.Name(), err)
		}
		parts := chunk(string(content), 600)
		for _, part := range parts {
			vec, err := embed(part)
			if err != nil {
				return nil, fmt.Errorf("embed %s: %w", f.Name(), err)
			}
			chunks = append(chunks, Chunk{Text: part, Source: f.Name(), Embedding: vec})
		}
		fmt.Printf(" %s: %d chunks\n", f.Name(), len(parts))
	}
	return chunks, nil
}
// --- Retrieval ---
// retrieve returns the topK chunks of docs most similar to query,
// ranked by cosine similarity of embeddings. If the query cannot be
// embedded it returns nil — the original ignored the embed error and
// ranked against a nil vector, yielding an arbitrary all-zero-score
// ordering.
func retrieve(docs []Chunk, query string, topK int) []Chunk {
	queryVec, err := embed(query)
	if err != nil {
		return nil
	}
	type scored struct {
		chunk Chunk
		score float64
	}
	results := make([]scored, len(docs))
	for i, doc := range docs {
		results[i] = scored{doc, cosineSimilarity(queryVec, doc.Embedding)}
	}
	sort.Slice(results, func(i, j int) bool {
		return results[i].score > results[j].score
	})
	top := make([]Chunk, 0, topK)
	for i := 0; i < topK && i < len(results); i++ {
		top = append(top, results[i].chunk)
	}
	return top
}
// --- Generation ---
// chat sends messages to the local Ollama chat API (llama3.2,
// non-streaming) and returns the assistant's reply text. It returns an
// error on request-encoding failure, transport failure, a non-200
// response, or malformed response JSON (all of which the original
// silently ignored).
func chat(messages []Message) (string, error) {
	body, err := json.Marshal(map[string]any{
		"model":    "llama3.2",
		"messages": messages,
		"stream":   false,
	})
	if err != nil {
		return "", fmt.Errorf("encoding chat request: %w", err)
	}
	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading chat response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("chat API status %d: %s", resp.StatusCode, data)
	}
	var result struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	}
	if err := json.Unmarshal(data, &result); err != nil {
		return "", fmt.Errorf("decoding chat response: %w", err)
	}
	return result.Message.Content, nil
}
// rag answers question from docs: it retrieves the 3 most relevant
// chunks, formats them as numbered, source-attributed context, and asks
// the chat model to answer strictly from that context with citations.
func rag(docs []Chunk, question string) (string, error) {
	hits := retrieve(docs, question, 3)

	var ctxText strings.Builder
	for idx, hit := range hits {
		fmt.Fprintf(&ctxText, "[%d] (%s)\n%s\n\n", idx+1, hit.Source, hit.Text)
	}

	systemPrompt := `Answer using ONLY the provided context.
If the context doesn't have the answer, say "I don't have that information."
Cite sources using [1], [2], etc.`
	userPrompt := fmt.Sprintf("Context:\n%s\nQuestion: %s", ctxText.String(), question)

	return chat([]Message{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: userPrompt},
	})
}
// --- Main ---
// main indexes ./docs, then runs an interactive question loop reading
// from stdin until EOF. Fatal indexing errors exit with status 1;
// diagnostics go to stderr instead of stdout so answers stay pipeable.
func main() {
	fmt.Println("Indexing docs/...")
	docs, err := indexDocuments("./docs")
	if err != nil {
		fmt.Fprintln(os.Stderr, "Error:", err)
		os.Exit(1)
	}
	fmt.Printf("Indexed %d chunks\n", len(docs))

	scanner := bufio.NewScanner(os.Stdin)
	for {
		fmt.Print("\nQuestion: ")
		if !scanner.Scan() {
			break
		}
		q := scanner.Text()
		if q == "" {
			continue
		}
		answer, err := rag(docs, q)
		if err != nil {
			fmt.Fprintln(os.Stderr, "Error:", err)
			continue
		}
		fmt.Println("\n" + answer)
	}
	// A plain EOF (Ctrl-D) leaves Err() nil; anything else (e.g. a line
	// exceeding the scanner's buffer) was previously swallowed silently.
	if err := scanner.Err(); err != nil {
		fmt.Fprintln(os.Stderr, "Input error:", err)
	}
}