ในยุคนี้ที่ AI และ Natural Language Processing (NLP) พัฒนาอย่างรวดเร็วและก้าวกระโดด โมเดลภาษาขนาดใหญ่ (Large Language Models: LLMs) อย่าง ChatGPT, Gemini และ AI เจ้าอื่นๆ ได้เข้ามากลายเป็นเพื่อนรู้ใจที่หลายคนใช้ปรึกษาเรื่องส่วนตัว ไปจนถึงเลขาที่รอบรู้ สามารถแพลนงานหรือตอบคำถามได้รอบด้าน ทำให้ LLM ถูกนำไปใช้ในหลากหลายอุตสาหกรรม ยกตัวอย่างเช่น
แต่ถึงแม้ LLM จะมีความสามารถรอบด้าน ตอบคำถามหลากหลายและสร้างข้อความได้ราวกับมนุษย์ แต่ก็ยังมี ข้อจำกัดสำคัญอยู่ นั่นก็คือ LLM ไม่สามารถเข้าถึงข้อมูลเฉพาะ เช่นข้อมูลภายในองค์กร หรือข้อมูลใหม่ล่าสุดที่อยู่นอกเหนือชุดข้อมูลที่ถูกเทรนมา
ถ้าข้อมูลไหนไม่ได้อยู่ในชุดเทรน LLM ก็จะนึกไม่ออกเลย เช่น ถามถึงตัวเรา ที่อาจไม่ได้มีชื่อเสียง, ข่าวหรือเหตุการณ์ใหม่ๆ หรือข้อมูลภายในองค์กรเป็นต้น
ผลคือ LLM ก็จะ งง และไม่สามารถตอบคำถามได้ ถ้าโดนถามอะไรที่อยู่นอกเหนือจากที่มันเคยเทรนมา ! 😵💫
เพื่อแก้ปัญหานี้ เราเลยต้องใช้เทคนิคเสริมพลังความรู้ให้ AI เช่น RAG (Retrieval-Augmented Generation) ที่ช่วยให้ LLM ไปค้นข้อมูลสดๆจากฐานความรู้ของเราได้ RAG จึงทำให้ LLM สร้างคำตอบที่แม่นยำและมีความเกี่ยวข้องกับบริบทมากขึ้น
เราไปเริ่มจากการ “meet & greet” RAG กันก่อนเล้ย!
ใน content นี้เราจะขอเกริ่นถึง RAG แบบคร่าวๆ ก่อนที่เราจะไปดูว่า CAG คืออะไร. . .
ถ้าอยากอ่าน RAG เชิงเทคนิคแบบจัดเต็ม แนะนำบทความนี้เลย! 👉🏼 ทำ On-Premise RAG ใช้เองง่าย ๆ ด้วย Ollama และ Python
Retrieval-Augmented Generation หรือ RAG
คือเทคนิคที่ให้ LLM ของเราออกไปค้นข้อมูลจาก knowledge base หรือ ฐานความรู้ของเรา ทุกครั้งที่มีคนถาม (เหมือนมีบรรณารักษ์ที่วิ่งไปหาหนังสือในห้องสมุดทุกครั้ง)
ข้อดี คือฉลาดและข้อมูลอัปเดตได้ตลอด แต่ ข้อเสีย คือ… บางที LLM อาจจะช้า เพราะมัวแต่เดินหาหนังสืออยู่ (หรือเรียกว่า Retrieval Latency นั่นเอง 🐢)
Concept การทำงานของ RAG แบบคร่าวๆ
ถ้ามองแบบง่ายๆ RAG เหมือนมีทีมงานคู่หูที่เข้าขากัน
เจ้านี่คือมือโปรด้านการค้นหา วิ่งวุ่นไปทั่วคลังข้อมูลหรือฐานความรู้ของเรา เพื่อควานหาข้อมูลเด็ดๆ ที่เกี่ยวข้องกับคำถามของเราให้มากที่สุด
(ใช้เทคนิคอย่าง vector embeddings เพื่อแปลงข้อความให้เป็นเวกเตอร์ (Vector) ซึ่งเป็นชุดตัวเลขที่แทนความหมายของข้อมูลนั้น กับ Similarity Search เอาเวกเตอร์ไปเทียบกับเวกเตอร์อื่นๆ ในฐานข้อมูล แล้วหาว่าอันไหนคล้ายกันที่สุด)
พอได้วัตถุดิบ (ข้อมูล) มาแล้ว เชฟ AI ก็จะนำคำถามของเรา + ข้อมูลที่ Retriever ไปหามา มาปรุงเป็นคำตอบใหม่ที่ทั้งตรงประเด็นและมีรายละเอียด
(เหมือนเชฟที่เอาวัตถุดิบสดใหม่มาทำเมนูพิเศษให้เรานั่นเอง !)
สองขั้นตอนนี้ช่วยให้ RAG ตอบได้แม่นยำ อ้างอิงข้อเท็จจริง และตรงใจเรากว่า LLM ทั่วไปที่รู้แค่สิ่งที่เคยถูกเทรนมาเท่านั้น
🎯 ข้อดีของ RAG
RAG มีข้อดีหลายอย่างที่น่าสนใจ . . .
อย่างแรกคือความสามารถในการอัปเดตข้อมูลได้ง่ายๆ แค่โยนข้อมูลใหม่เข้า ฐานข้อมูล เหมือนมีทีมงานที่พร้อมอัปเดตความรู้ตลอดเวลา 🆕
ข้อดีที่สองคือช่วยลด Hallucination ได้ดี เพราะมันตอบจากข้อมูลจริงๆ และสามารถอ้างอิงแหล่งข้อมูลได้ เหมือนมีหลักฐานประกอบการตอบ 🎯
ยิ่งไปกว่านั้น RAG ยังรองรับข้อมูลได้เยอะมาก ไม่ว่าเอกสารหรือข้อมูลเยอะมากขนาดไหน ก็ไม่หวั่น เหมือนมีห้องสมุดขนาดใหญ่และบรรณารักษ์ที่ค้นหาได้เร็ว 📚
สุดท้ายคือความยืดหยุ่นในการค้นหา เราสามารถปรับแต่ง retrieval strategy และเลือกใช้ embedding model ตามความเหมาะสมได้ 🔍
⚠️ แต่ RAG ก็มีข้อเสียที่ต้องระวังเหมือนกัน !
อย่างแรกคือต้องดูแลฐานข้อมูลให้และจัดการ embedding ให้ถูกต้อง เหมือนต้องดูแลห้องสมุดให้เป็นระเบียบตลอดเวลา 🛠️
ข้อเสียที่สองคือมี Retrieval Latency ต้องรอการค้นหาข้อมูล เหมือนต้องรอบรรณารักษ์ไปหาหนังสือมาให้ ⏳
นอกจากนี้ยังอาจเจอปัญหา Chunking ถ้าแบ่งข้อมูลไม่ดี อาจได้ context ไม่ครบ เหมือนถ้าแบ่งหนังสือผิดหน้า อาจได้เนื้อหาไม่ครบ 📄
สุดท้ายคือต้องจัดการ embedding ให้ดี ต้องเลือก model ให้เหมาะและตั้งค่า parameters ให้ถูกต้อง เหมือนต้องสอนทีมงานให้เข้าใจภาษาเดียวกัน 🔢
แล้ว . . . อะไรคือ CAG ล่ะ ?
Cache-Augmented Generation หรือ CAG
คือทางเลือกสายสปีด 🚀 แทนที่จะวิ่งไปค้นข้อมูลทุกครั้ง CAG จะกินหรือ Preload ข้อมูลที่ต้องใช้เข้าไปยัง context window ของ LLM ตั้งแต่แรก แล้วให้ LLM จำไว้ในหัว เรียกว่า Context Cache (หรือ KV cache)
แบบนี้ LLM จะตอบเร็วมากกกกก เพราะไม่ต้องวิ่งไปไหนอีกแล้ว — ข้อมูลอยู่ในหัวหมดแล้ว 🚀
แต่ๆๆ…ที่เกริ่นไว้ด้านบนว่า CAG จะทำการ Preload ข้อมูลไปยัง context window ของ LLM ตั้งแต่แรก นั่นหมายความว่า ถ้า size ของ context window ของ LLM (จำนวน token สูงสุดที่ LLM รับได้ใน 1 prompt) เล็กกว่าข้อมูลหรือ knowledge source ของเรา หรือพูดง่ายๆว่าปริมาณข้อมูลของเรามีขนาดใหญ่กว่าที่ context window ของ LLM จะรับได้ CAG ก็จะมีข้อมูลที่ได้รับมาไม่ครบ และทำให้คำตอบที่ได้ออกมา อาจจะไม่ครบหรือผิดพลาดได้
ดังนั้น ถ้าข้อมูลของเรามีปริมาณไม่ใหญ่เกินไป หรือเพียงพอที่จะ “แปะลงบนขนมปังช่วยจำ” (context window) ได้หมด CAG ก็จะตอบไว ตอบครบ ไม่ต้องเสียเวลาค้นหาอะไรอีก
แต่ถ้าข้อมูลเยอะเกินไป CAG ก็จะจำข้อมูลไม่ได้ทั้งหมด ก็เหมือนพยายามแปะขนมปังช่วยจำไปยังข้อมูลที่มีขนาดเท่ากับกระดาษ A1 ซึ่งใหญ่กว่าขนมปังช่วยจำ — สุดท้ายก็ต้องเลือกแปะเฉพาะส่วนที่สำคัญ หรือไม่ก็ต้องหาวิธีใหม่นั่นเอง. . .
เราไปดูกันว่า CAG ทำงานยังไง และไปลอง Code กัน !
หลักการทำงานของ CAG แบบเข้าใจง่ายๆ
1. เตรียมข้อมูล (Preparing Knowledge Source): จากรูปด้านบน เราจะต้องทำการเตรียม knowledge source หรือ ข้อมูล ของเราให้พร้อมที่จะถูก Preload ไปยัง LLM ซึ่งอาจจะเป็นไฟล์ txt, pdf, image, Excel, Powerpoint, Markdown เป็นต้น มาแปลง (extract) ให้อยู่ในรูปแบบของ text เช่น .txt, .md (ในกรณีที่ knowledge source ไม่ใช่ไฟล์ที่เป็น txt หรือ markdown) เพราะว่า LLMs สามารถเข้าใจได้แค่ text ยังไงล่ะ
2. โหลดข้อมูลเข้า LLM’s context window (Preloading Knowledge): เมื่อข้อมูลของเราพร้อมที่จะใช้งานแล้ว Step ต่อมาก็คือการนำข้อมูลเหล่านั้นมาโหลด หรือ “Preload” เข้า context window ของ LLM และสร้าง KV Cache เก็บไว้ในตัวของ LLM เองเลย
3. เวลามี query (Answering with Preloaded Knowledge): เจ้า LLM ที่ได้ผ่านการ Preload ข้อมูลข้อมูลมาแล้วก็เหมือนมี Cheat Sheet ที่ใช้นำไปเข้าห้องสอบ ทำให้สามารถหยิบข้อมูลที่ถูก cache ไว้ในตอนแรกมาใช้ตอบคำถามได้เลย!
. . .แล้ว CAG จัดการกับบริบทหรือ knowledge สำหรับตอบคำถามยังไงล่ะ ?
CAG ใช้กลไกที่เรียกว่า Self-Attention Mechanism ซึ่งเป็นหัวใจสำคัญของ Transformer models โดยทำงานดังนี้
แนวคิดของ Self-Attention คือ นำคำทุกคำในประโยคมาเปรียบเทียบกันเอง ให้โมเดลเรียนรู้ และเลือกเอง ว่าจะสนใจคำไหน เมื่อไหร่ ด้วยการ แปลง Input ให้เป็น 3 Vectors คือ
ทั้งหมดนี้ได้จากการ คูณ input (X) กับ weight matrix (WQ, WK, WV) ซึ่งเป็น linear layer ที่โมเดลเรียนรู้เองระหว่างการเทรน
จากนั้นนำไปคำนวณ attention score โดยการ:
ขั้นตอนสุดท้ายคือการหา Weighted Sum (Output) โดยการ นำ Attention Weights ที่ได้ไปคูณกับ V ของแต่ละคำแล้วรวมกัน เราจะได้ vector ใหม่สำหรับแต่ละคำในประโยค เพื่อสรุปบริบทสำคัญ สำหรับแต่ละคำ หรือพูดง่าย ๆ คือ “สำหรับคำๆนี้ ถ้าอยากเข้าใจมันดีขึ้น ต้องใส่ใจคำไหนในประโยคบ้าง ”
ในโมเดลจริงจะมีหลาย head (Multi-Head) เพื่อทำ self-attention พร้อมกัน แล้วนำผลลัพธ์มารวม (concatenate) และผ่าน linear layer อีกที
⚠️ Note: self-attention ทำงานอัตโนมัติในตัว LLM เราไม่ต้องเขียนโค้ดเพิ่มเติม
หลังจากนี้ โมเดลจะใช้ผลลัพธ์จาก self-attention ไปคำนวณความน่าจะเป็นของ token ถัดไป แล้วเลือก token ที่จะตอบ (เช่นด้วย greedy decoding)
. . .ถ้าอ่านมาถึงตรงนี้ คุณได้เห็นความแตกต่างระหว่าง RAG และ CAG แล้ว
CAG ไม่มี Information Retrieval (IR) หรือเจ้าตัว Retriever ที่เกริ่นไปก่อนหน้าเหมือนกับ RAG ทำให้เราไม่ต้องรอการ Retrieve ข้อมูลที่เกี่ยวข้องจากฐานข้อมูล แต่ CAG ใช้ caches ที่เก็บไว้ใน context window ของ LLM มาใช้ในการตอบคำถามนั่นเอง
และด้วยเหตุนี้เอง ทำให้ CAG สามารถตอบคำถามได้รวดเร็วกว่า RAG และไม่จำเป็นที่จะต้องแนบ relevant documents หรือบริบทที่เกี่ยวของไปพร้อมกับ query อย่างที่ RAG ทำ
เอาล่ะ! ใน tutorial นี้ เราจะพาไปดูวิธีสร้าง CAG แบบง่ายๆ:
เรามาเริ่มจากการเตรียมของที่ต้องใช้ให้พร้อมกันก่อน
1. HuggingFace Account — (เอาไว้โหลด LLM models ฟรี)
2. ไฟล์ document.txt — ที่เปรียบเสมือนเป็น Knowledge Source ของเราที่พร้อมใช้งานแล้ว (ในตัวอย่างนี้จะใช้เป็น Company FAQ ที่ใช้ตอบคำถามที่เกี่ยวข้องกับ PALO-IT)
3. Python — แนะนำให้ใช้ Python เวอร์ชัน 3.12 ขึ้นไป เพื่อรองรับไลบรารีใหม่ๆ ได้ครบถ้วน แนะนำให้ใช้ pyenv ในการติดตั้งจะได้ control python version ได้ง่ายๆ
เราจะใช้ 3 libraries หลักๆ ที่ขาดไม่ได้เลย
torch สำหรับ Pytorchtransformers กล่องเครื่องมือ ที่ช่วยให้เราใช้งาน LLM ได้ง่ายและสะดวกขึ้น เช่น แปลงข้อความเป็น token และแปลง token กลับเป็นข้อความอัตโนมัติ, โหลดโมเดลภาษา (เช่น GPT, LLaMA, Mistral) มาใช้งานได้ในบรรทัดเดียวDynamicCache ตัวช่วยเก็บ KV cachesภาพรวมของ structure จะเป็นประมาณนี้
cag
├── document.txt # Knowledge source ของเรา
├── cag.py # ไฟล์หลักของเราที่จะใช้ run script
├── .env # เก็บ environment variable ต่างๆ
└── requirements.txt # python lib(s) ที่เราจะใช้
ใน requirement.txt มี libraries ดังนี้
torch
transformers
เราจะเก็บ Hugging Face variables ไว้ใน .envตามนี้เลย
HF_TOKEN=your_HF_token_here
ถึงเวลาไฟล์หลักที่เป็นหัวใจของ CAG: cag.py
import os
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from dotenv import load_dotenv
ในตัวอย่างนี้เราใช้เป็น Mistral-7B-Instruct-v0.1 ซึ่งเป็นโมเดลขนาดกลางที่เหมาะสมทั้งด้านคุณภาพ, ความเร็ว, และ resource ที่ใช้ เหมาะกับงาน Chatbot, FAQ, หรือ CAG ที่ต้องการ balance ระหว่างประสิทธิภาพและความประหยัด และยังมี context window ที่รับได้ 32.8k tokens ซึ่งเพียงพอกับ knowledge source ของเรา
load_dotenv()
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
HF_TOKEN = os.getenv("HF_TOKEN")
ถ้า knowledge source ใหญ่กว่านี้ คุณอาจต้องเลือกโมเดลที่รองรับ context window ที่ใหญ่ขึ้น เช่น
เราควรเลือกโมเดลให้เหมาะกับขนาด knowledge source และ resource ที่มี
ถ้า context window ไม่พอ ให้พิจารณา chunk ข้อมูล หรืออาจจะเปลี่ยนไปใช้ RAG แทน
ก่อนจะใช้งาน LLM อย่าง Mistral-7B เราต้องเลือกก่อนว่าจะรันบน GPU (เร็วสุด) หรือ CPU (ถ้าไม่มี GPU) และเลือก precision ที่เหมาะสม (FP16 สำหรับ GPU, FP32 สำหรับ CPU) เพื่อให้ประหยัด resource และได้ performance ที่ดีที่สุด
ถ้ามี GPU (เช่น NVIDIA RTX, A100, T4 ฯลฯ) จะใช้ FP16 (half precision) เพื่อประหยัด VRAM และรันเร็วขึ้น ถ้าไม่มี GPU จะ fallback ไปใช้ CPU และ FP32 (full precision) แทน
# ฟังก์ชันนี้ช่วยเลือกว่าจะรันบน GPU (เร็วสุด) หรือ CPU (ถ้าไม่มี GPU)
# และเลือก dtype(data type)ให้เหมาะสม
def get_device_and_dtype():
if torch.cuda.is_available():
return torch.device("cuda"), torch.float16
else:
return torch.device("cpu"), torch.float32
# โหลด model และ tokenizer เพื่อนำมาสร้าง tokens
def load_model_and_tokenizer(model_name, hf_token):
device, dtype = get_device_and_dtype()
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto",
trust_remote_code=True,
token=hf_token
)
model.to(device)
return tokenizer, model, device
# เราจะได้ tokenizer, model และ device เป็น output ของ function นี้
_tokenizer, _model, _device = load_model_and_tokenizer(MODEL_NAME, HF_TOKEN)
ขอขยายความว่า dtype (data type) คืออะไร ?
dtype คือ “ชนิดของข้อมูล” ที่ใช้เก็บตัวเลขใน tensor หรือ array ของโมเดล AI/Deep Learning
โดยเฉพาะใน PyTorch หรือ NumPy จะมี dtype หลายแบบ เช่น
ในบริบทของ LLM:
อย่าลืม! ต้องใส่ HuggingFace Token ของตัวเอง (HF_TOKEN) เพื่อโหลดโมเดลจาก HuggingFace Hub
ฟังก์ชันนี้จะอ่านไฟล์ document.txt ของเรา (หรือ Company FAQ) แล้วคืนข้อความทั้งหมดมาใช้เป็น context ของเราที่จะนำไปรวมกับ system prompt ในขั้นตอนถัดไป และโหลดเข้า context window ของ LLM ต่อ
(ถ้า knowledge source ของเรายังเป็น format อื่นที่ไม่ใช่ text เช่น รูปภาพ, pdf ฯลฯ ก็ต้องมา extract หรือแปลงให้เป็น .txt หรือ .md ก่อน)
def load_knowledge_for_cag(path="document.txt"):
with open(path, "r", encoding="utf-8") as f:
return f.read()
print("Loading knowledge...")
faq_text = load_knowledge_for_cag()
print(f"Loaded ({len(faq_text)} characters)")
ในขั้นตอนนี้ เราจะทำการ construct หรือสร้าง system prompt ของเรา ที่จะนำไปป้อนให้เจ้า LLM โดยการรวม context (knowledge source) ของเราเข้ากับ system prompt เพื่อให้ LLM เข้าใจบทบาทและข้อมูลที่ต้องใช้
print("[CAG] Preparing system prompt...")
system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{faq_text}
Question:
""".strip()
ผู้อ่านอาจจะสงสัยว่า <|system|> กับ <|user|> คืออะไร ?
เจ้า 2 ตัวนี้ใน code ตัวอย่าง เรียกว่า “Special tokens” หรือ “โทเคนพิเศษ” ที่ถูกออกแบบมาเพื่อบอก LLM ว่า “ตอนนี้กำลังอยู่ในบทบาทไหน” หรือ “ข้อความนี้เป็นของใคร”
ใน LLM รุ่นใหม่ๆ (เช่น Mistral, Llama, ChatGPT, Gemini, Claude) จะมีการใช้ special tokens เพื่อแยกบทสนทนา, system prompt, หรือ context ต่างๆ
ตัวอย่างใน code เรา
เปรียบเทียบง่ายๆ:
หมายเหตุ: แต่ละโมเดลอาจใช้ special tokens ไม่เหมือนกัน (เช่น <s>, <|assistant|>, [INST] ฯลฯ) ควรดู documentation ของโมเดลนั้นๆ ว่า support โทเคนแบบไหน
ในขั้นตอนนี้ เราจะสร้าง KV Cache (Key-Value Cache) ซึ่งเป็นกลไกสำคัญของโมเดล Transformer ที่ช่วยให้ LLM “จำ” context หรือ knowledge ทั้งหมดที่เรา preload ไว้ การมี KV Cache ทำให้เวลาตอบคำถามใหม่ๆ โมเดลไม่ต้องประมวลผล context ทั้งหมดซ้ำอีกครั้ง — ประหยัดเวลาและ resource มาก!
# ฟังก์ชันที่ใช้สร้าง KV cache สำหรับ context ของเรา
def get_kv_cache(model, tokenizer, prompt: str):
print("[CAG] Building KV cache...")
t1 = time.time()
# บอกว่า embedding layer ของโมเดลนี้อยู่บน device ไหน (เช่น "cuda:0" หรือ "cpu")
# เพื่อให้แน่ใจว่า input IDs ที่เราสร้างขึ้นจะถูกส่งไปยัง device เดียวกับโมเดล (ถ้าไม่ตรงกันจะ error)
device = model.model.embed_tokens.weight.device
# ทำการ tokenize prompt ของเราให้เป็น input IDs
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# object ที่ใช้เก็บ key-value pairs (kv cache)
cache = DynamicCache()
# context manager ของ PyTorch ที่บอกว่า "ในบล็อกนี้ ไม่ต้องเก็บ gradient"
# ปกติการรันโมเดลจะเก็บ gradient สำหรับการเทรน แต่สำหรับ inference หรือการสร้าง cache เราไม่ต้องการ gradient
with torch.no_grad():
_ = model(input_ids=input_ids, past_key_values=cache, use_cache=True)
t2 = time.time()
print(f"[CAG] KV cache built. Length: {input_ids.shape[-1]}. Time: {t2-t1:.2f}s")
return cache, input_ids.shape[-1]
# ได้ kv cache, original length ของ cache เป็น output
_kv_cache, _origin_len = get_kv_cache(_model, _tokenizer, system_prompt)
get_kv_cache คือ ฟังก์ชันที่สร้าง KV cache สำหรับ context ของเราโดยรับ input 3 อย่าง ดังนี้
model: โมเดล LLM ที่จะใช้ encode contexttokenizer: ตัวแปลงข้อความเป็น token IDsprompt: ข้อความ context ที่จะ preload (เช่น knowledge ทั้งหมดจาก document.txt)ในฟังก์ชัน:
ฟังก์ชันนี้มีหน้าที่ ตัด หรือ trim ข้อมูลใน KV cache (ทั้ง key และ value) ให้ cache กลับมามีขนาดเท่าเดิม เหมือนตอนที่เราสร้าง cache จาก step ก่อนหน้า
หรือพูดง่ายๆก็คือทำให้ขนาดของ cache กลับมาเท่ากับ _origin_len ที่เราได้มาจาก step ก่อนหน้า
def clean_up(cache: DynamicCache, origin_len: int):
for i in range(len(cache.key_cache)):
cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
เหตุผลที่เราต้อง clean up หรือรีเซ็ต cache นั้นก็เพราะว่าเวลาที่ LLM generate คำตอบใหม่ๆ ต่อท้าย context เดิม, KV cache จะขยายออกไปเรื่อยๆ
ถ้าไม่ reset หรือ trim cache ก่อน อาจเกิด “cache ปน” ทำให้คำตอบผิดเพี้ยนได้
เพื่อให้เห็นภาพ เรามาดูตัวอย่างด้านล่างกัน
ตัว prompt ที่ LLM เห็น (ในเชิง KV cache) จะ “ต่อท้าย” query ใหม่เข้าไปหลัง sequence เดิมแบบนี้ 👇🏼
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
[FAQ ทั้งหมด]
Question:
PALO IT คืออะไร?
2. ถ้าไม่ clean up แล้วมี query ใหม่ เช่น “PALO IT ทำอะไรบ้าง?”
prompt ที่ LLM เห็น (ในเชิง KV cache) จะกลายเป็น:
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
[FAQ ทั้งหมด]
Question:
PALO IT คืออะไร?
PALO IT ทำอะไรบ้าง?
ผลลัพธ์ก็คือออ . . . LLM จะเห็นทั้ง query เก่าและ query ใหม่ใน prompt เดียวกัน อาจทำให้โมเดลสับสน ตอบผิด หรือเอาคำถามเก่ามาปนกับคำถามใหม่ ถ้าทำหลายรอบ prompt จะยาวขึ้นเรื่อยๆ (query สะสมต่อท้าย context เดิม) และนี่คือเหตุผลว่าทำไมเราควร clean up cache ทุกครั้งที่มี query เข้ามานั่นเอง ✨
ขั้นตอนนี้เป็นขั้นตอนสำคัญที่จะให้ LLM generate คำตอบออกมาจาก cache ที่ถูก preload ไว้แล้ว
ฟังก์ชัน generate นี้คือหัวใจของการสร้างคำตอบ จาก LLM โดยใช้ KV cache ที่เตรียมไว้ LLM จะ generate ข้อความแบบ token-by-token (ทีละคำ) ด้วยวิธี greedy decoding (คือเลือก token ที่มีความน่าจะเป็นสูงสุดในแต่ละรอบ)
ฟังก์ชั่นนี้รับ 4 inputs ดังนี้
model: โมเดล LLM ที่จะใช้ generate คำตอบ ในตัวอย่างนี้คือ Mistral-7Binput_ids: Tensor ที่เก็บ token ของ input sequence (เช่น query ที่เราต้องการถาม) ได้มาจากการแปลงข้อความเป็น token ด้วย tokenizerpast_key_values: หัวใจของ CAG! นั่นก็คือ cache ที่เก็บค่า attention (key/value) ที่คำนวณไว้แล้วจาก context เดิมmax_new_tokens: จำนวน token สูงสุดที่ต้องการให้ generate ใหม่ ค่า default ในตัวอย่างคือ 150
def generate(model, input_ids, past_key_values, max_new_tokens=150):
# กำหนด device (CPU/GPU) ให้ตรงกับโมเดล
device = model.model.embed_tokens.weight.device
# เอาไว้จำว่าความยาว input เดิมเท่าไหร่ (จะได้ตัด output เฉพาะส่วนที่ generate ใหม่)
origin_len = input_ids.shape[-1]
# token ของ query ที่จะเริ่ม generate
input_ids = input_ids.to(device)
# copy input_ids ไว้เป็นจุดเริ่มต้นของ output sequence (จะต่อ token ใหม่เข้าไปเรื่อยๆ)
output_ids = input_ids.clone()
# กำหนด token แรกที่จะใช้ generate (เริ่มจาก input_ids เดิม)
next_token = input_ids
# torch.no_grad() เพื่อไม่เก็บ gradient (ประหยัด resource)
with torch.no_grad():
for _ in range(max_new_tokens):
out = model(
input_ids=next_token,
past_key_values=past_key_values,
use_cache=True
)
logits = out.logits[:, -1, :]
token = torch.argmax(logits, dim=-1, keepdim=True)
output_ids = torch.cat([output_ids, token], dim=-1)
past_key_values = out.past_key_values
next_token = token.to(device)
if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
break
# ได้ output ออกมาเป็น tokens ที่ถูก generate เป็นคำตอบของ query เรา
return output_ids[:, origin_len:]
อธิบายแต่ละขั้นตอนในฟังก์ชั่นนี้
2. วนลูป generate ทีละ token
3. คืนเฉพาะ token ที่ generate ใหม่ (ไม่รวม input เดิม)
สรุปสั้นๆ — ฟังก์ชันนี้ “ต่อยอด” จาก KV cache ที่มี context เดิมอยู่แล้ว และสร้างคำตอบทีละ token แบบ greedy (เลือกคำที่น่าจะเป็นไปได้มากสุดในแต่ละรอบ) การทำแบบนี้จะเร็วและประหยัด resource เพราะใช้ cache ไม่ต้อง reload context เดิมซ้ำยังไงล่ะ!
ขั้นตอนนี้เราจะทำ CLI Chatbot โดยใช้สิ่งที่เราเตรียมกันมาจาก steps ก่อนๆหน้า
if __name__ == "__main__":
print("\n=== PALO IT CAG Chatbot ===")
print("ถามอะไรก็ได้เกี่ยวกับ PALO IT (พิมพ์ 'exit' เพื่อออก)\n")
# สร้าง Loop เพื่อใช้เป็น chatbot
while True:
query = input("> ").strip()
if query.lower() in {"exit", "quit", "q"}:
print("ลาก่อน!")
break
if not query:
continue
# clean up cache ก่อนทุกๆครั้งที่มี query เข้ามา
clean_up(_kv_cache, _origin_len)
# แปลง query ให้เป็น token โดยการ tokenize
input_ids = _tokenizer(query + "\n", return_tensors="pt").input_ids.to(_device)
# จับเวลาเริ่มที่ CAG ใช้ generate คำตอบ
start_gen = time.time()
# generate คำตอบโดยใช้ฟังก์ชั่น generate
output_ids = generate(_model, input_ids, _kv_cache, max_new_tokens=100)
# จับเวลาจบที่ CAG ใช้ generate คำตอบ
end_gen = time.time()
# decode คำตอบที่ถูกสร้างจาก function generate จาก tokens ให้เป็นคำตอบ
answer = _tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\n[CAG] {answer.strip()} (ตอบใน {end_gen - start_gen:.2f} วินาที)\n")
และนี่คือภาพรวมของ cag.py
import os
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from dotenv import load_dotenv
load_dotenv()
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
HF_TOKEN = os.getenv("HF_TOKEN")
# ฟังก์ชันนี้ช่วยเลือกว่าจะรันบน GPU (เร็วสุด) หรือ CPU (ถ้าไม่มี GPU)
# และเลือก dtype(data type)ให้เหมาะสม
def get_device_and_dtype():
if torch.cuda.is_available():
return torch.device("cuda"), torch.float16
else:
return torch.device("cpu"), torch.float32
# โหลด model และ tokenizer เพื่อนำมาสร้าง tokens
def load_model_and_tokenizer(model_name, hf_token):
device, dtype = get_device_and_dtype()
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto",
trust_remote_code=True,
token=hf_token
)
model.to(device)
return tokenizer, model, device
# เราจะได้ tokenizer, model และ device เป็น output ของ function นี้
_tokenizer, _model, _device = load_model_and_tokenizer(MODEL_NAME, HF_TOKEN)
# เปิดและอ่านไฟล์ document.txt
def load_knowledge_for_cag(path="document.txt"):
with open(path, "r", encoding="utf-8") as f:
return f.read()
print("Loading knowledge...")
faq_text = load_knowledge_for_cag()
print(f"Loaded ({len(faq_text)} characters)")
# สร้าง system prompt
print("[CAG] Preparing system prompt...")
system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
<|user|>
Context:
{faq_text}
Question:
""".strip()
# ฟังก์ชันที่ใช้สร้าง KV cache สำหรับ context ของเรา
def get_kv_cache(model, tokenizer, prompt: str):
print("[CAG] Building KV cache...")
t1 = time.time()
# บอกว่า embedding layer ของโมเดลนี้อยู่บน device ไหน (เช่น "cuda:0" หรือ "cpu")
# เพื่อให้แน่ใจว่า input IDs ที่เราสร้างขึ้นจะถูกส่งไปยัง device เดียวกับโมเดล (ถ้าไม่ตรงกันจะ error)
device = model.model.embed_tokens.weight.device
# ทำการ tokenize prompt ของเราให้เป็น input IDs
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# object ที่ใช้เก็บ key-value pairs (kv cache)
cache = DynamicCache()
# context manager ของ PyTorch ที่บอกว่า "ในบล็อกนี้ ไม่ต้องเก็บ gradient"
# ปกติการรันโมเดลจะเก็บ gradient สำหรับการเทรน แต่สำหรับ inference หรือการสร้าง cache เราไม่ต้องการ gradient
with torch.no_grad():
_ = model(input_ids=input_ids, past_key_values=cache, use_cache=True)
t2 = time.time()
print(f"[CAG] KV cache built. Length: {input_ids.shape[-1]}. Time: {t2-t1:.2f}s")
return cache, input_ids.shape[-1]
# ได้ kv cache, original length ของ cache เป็น output
_kv_cache, _origin_len = get_kv_cache(_model, _tokenizer, system_prompt)
# ฟังก์ชั่นที่ใช้ reset cache
def clean_up(cache: DynamicCache, origin_len: int):
for i in range(len(cache.key_cache)):
cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
def generate(model, input_ids, past_key_values, max_new_tokens=150):
# กำหนด device (CPU/GPU) ให้ตรงกับโมเดล
device = model.model.embed_tokens.weight.device
origin_len = input_ids.shape[-1]
# token ของ query ที่จะเริ่ม generate
input_ids = input_ids.to(device)
output_ids = input_ids.clone()
next_token = input_ids
# torch.no_grad() เพื่อไม่เก็บ gradient (ประหยัด resource)
with torch.no_grad():
for _ in range(max_new_tokens):
out = model(
input_ids=next_token,
past_key_values=past_key_values,
use_cache=True
)
logits = out.logits[:, -1, :]
token = torch.argmax(logits, dim=-1, keepdim=True)
output_ids = torch.cat([output_ids, token], dim=-1)
past_key_values = out.past_key_values
next_token = token.to(device)
if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
break
# ได้ output ออกมาเป็น tokens ที่ถูก generate เป็นคำตอบของ query เรา
return output_ids[:, origin_len:]
if __name__ == "__main__":
print("\n=== CAG Chatbot ===")
print("(type 'exit' to quit)\n")
# สร้าง Loop เพื่อใช้เป็น chatbot
while True:
query = input("> ").strip()
if query.lower() in {"exit", "quit", "q"}:
print("ลาก่อน!")
break
if not query:
continue
# clean up cache ก่อนทุกๆครั้งที่มี query เข้ามา
clean_up(_kv_cache, _origin_len)
# แปลง query ให้เป็น token โดยการ tokenize
input_ids = _tokenizer(query + "\n", return_tensors="pt").input_ids.to(_device)
# จับเวลาเริ่มที่ CAG ใช้ generate คำตอบ
start_gen = time.time()
# generate คำตอบโดยใช้ฟังก์ชั่น generate
output_ids = generate(_model, input_ids, _kv_cache, max_new_tokens=100)
# จับเวลาจบที่ CAG ใช้ generate คำตอบ
end_gen = time.time()
# decode คำตอบที่ถูกสร้างจาก function generate จาก tokens ให้เป็นคำตอบ
answer = _tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\n[AI] {answer.strip()} (answer in {end_gen - start_gen:.2f} second(s))\n")
. . .มาลองรันเพื่อทดสอบกัน !
ในตัวอย่างนี้จะใช้ Google Colab เพื่อรันทดสอบ CAG บน GPU A100 ซึ่งช่วยให้โหลดและรันโมเดลขนาดใหญ่ได้เร็วและไม่ติดปัญหาเรื่อง RAM เพราะมี VRAM สูง
รัน script ด้วย command นี้: python3 cag.py
จากในรูปจะเห็นได้ว่าเมื่อ script โหลดโมเดลและเตรียม cache ที่เป็น company FAQ ของ PALO-IT เรียบร้อยแล้ว และลองป้อน query ว่า “What is PALO IT?”
คำตอบที่เราได้ออกมาคืออ
query: What is PALO IT?
answer: PALO IT is a global technology consulting firm that helps organizations
transform their businesses through sustainable digital solutions.
They specialize in agile development, human-centered design,
and emerging technologies to create positive impact. (ตอบใน 2.07 วินาที)
ลองถามอีกซักหนึ่งคำถาม. . .
CAG ของเราสามารถตอบคำถามเกี่ยวกับ PALO-IT FAQ ได้เร็วและถูกต้อง เพราะว่า knowledge source ของเราถูกแปลงไปเป็น KV cache และนำมาใช้ตอบคำถามนั่นเอง! ✨ ✅
. . .มาสรุปสิ่งที่เราได้เรียนรู้มากัน
CAG — จะสร้างคำตอบได้เร็วกว่า RAG เพราะว่า CAG ไม่มีขา Information Retrieval (IR) เพราะข้อมูลทั้งหมด (knowledge source) ถูก preload เข้า context window และ KV cache ตั้งแต่แรก
แต่. . .CAG จะไม่เหมาะกับข้อมูลที่มีปริมาณมากเกินกว่าที่ context window ของ LLMs รับได้ เพราะว่าข้อมูลที่ถูกโหลดเข้าไปจะไม่ครบ และทำให้คำตอบที่ generate ออกมานั้นมีคุณภาพที่ไม่ดีพอ และ CAG เหมาะสมกับ static data หรือข้อมูลที่ไม่ค่อยถูกอัพเดทบ่อยทั้งหลาย เช่น FAQ, เมนูร้านอาหาร, คู่มือการใช้งาน product เป็นต้น
ดังนั้น. . . ถ้า knowledge source ของเราเล็กพอที่จะอัดเข้า context window ได้หมด และไม่เปลี่ยนบ่อย CAG คือทางเลือกที่ตอบโจทย์ — ตอบไว ประหยัด resource เพราะไม่ต้องแนบ relevant context หรือบริบทที่เกี่ยวข้องไปพร้อมกับทุกๆ queries
แต่ถ้าต้องการรองรับข้อมูลขนาดใหญ่, ข้อมูลที่อัปเดตบ่อย, หรืออยากได้ citation อ้างอิง RAG จะเหมาะสมกว่า เพราะสามารถดึงเฉพาะข้อมูลที่เกี่ยวข้องมาตอบได้โดยไม่ติดข้อจำกัดของ context window หรือในบาง use case อาจใช้ hybrid approach คือ RAG ดึง context เฉพาะที่เกี่ยวข้อง แล้วค่อยใช้ CAG ตอบใน multi-turn chat ก็ได้เหมือนกัน
มีของแถม !! ผู้เขียนได้ทำ benchmark เพื่อเปรียบเทียบ performance ระหว่าง RAG และ CAG ใน knowledge source เดียวกัน! โดย flow จะเป็นดังนี้
จากรูปด้านบน เมื่อมี query เข้ามา ระบบจะส่งต่อ query ไปยัง CAG และ RAG module เพื่อประมวลผลพร้อมกัน
สำหรับ RAG module — จะเริ่มจากการนำเอกสาร PALO-IT FAQ มาเตรียมการก่อน โดยแบ่งเนื้อหาออกเป็นส่วนย่อยๆ (chunk) จากนั้นแปลงแต่ละส่วนให้อยู่ในรูปแบบ vector (embed) แล้วจัดเก็บไว้ในฐานข้อมูล Qdrant ซึ่งเป็น vector database เมื่อมีคำถามเข้ามา RAG จะค้นหาเฉพาะข้อมูลที่เกี่ยวข้องจาก Qdrant แล้วนำข้อมูลนั้นพร้อมกับคำถามส่งต่อให้ LLM สร้างคำตอบ โดยที่เริ่มจับเวลาตั้งแต่ขา retrieval ไปจนถึง generation
ในขณะที่ CAG module ทำงานในอีกรูปแบบ คือจะโหลดเอกสาร PALO-IT FAQ ทั้งหมดเข้าไปเก็บไว้ใน context window ของ LLM ตั้งแต่เริ่มต้น เมื่อมีคำถามเข้ามา LLM สามารถใช้ข้อมูลที่เก็บไว้แล้วในการสร้างคำตอบได้ทันที โดยไม่ต้องผ่านขั้นตอนการค้นหาเพิ่มเติม โดยที่เริ่มจับเวลาตอนที่ LLM เริ่ม generate คำตอบ
และนี่คือผลลัพธ์ของ benchmark นี้
เราสามารถเห็นได้ว่า CAG ตอบคำถามได้เร็วกว่า RAG จริงๆใน query และ knowledge source เดียวกัน !! 💨
คำถามส่งท้ายไว้ลองคิดตาม: scenario นี้ เลือก RAG หรือ CAG ดี ?
ลองคิดตามสถานการณ์เหล่านี้:
เฉลยอยู่ด้านล่างนี้ 👇🏼
สิ่งที่เราได้เรียนรู้จาก content นี้ก็คือ alternative approach of RAG หรือ CAG ที่สามารถทำให้ LLMs ตอบคำถามที่เฉพาะเจาะจง (based on knowledge source ของเรา) โดยที่ไม่ต้องใช้ Information Retrieval เพื่อไปดึงบริบทที่เกี่ยวข้องเหมือนกับ RAG นั่นเอง. . .
Key Takeaway — ไม่มี approach ไหนดีที่สุด ขึ้นอยู่กับ use case และความต้องการของระบบ! เลือกให้เหมาะกับงาน อย่าติดกับดัก “one-size-fits-all” 😉
สำหรับใครที่กำลังมองหาวิธีสร้าง RAG, CAG Application หรือ Chatbot เพื่อใช้งานในองค์กร ที่ PALO IT เรามีทีมผู้เชี่ยวชาญพร้อมช่วยตั้งแต่เริ่มต้นจนระบบใช้งานได้จริง! ไม่ว่าจะเป็น
ไม่ว่าคุณจะเพิ่งเริ่มต้น หรือมีระบบอยู่แล้วและอยากต่อยอด เราพร้อมเป็น partner ที่จะช่วยให้คุณไปได้ไกลกว่าเดิม
ทักไปที่เพจ Facebook: PALO IT Thailand ได้เลยครับ 🎉