Source code for aita_core.rag

"""
RAG pipeline for AITA.
No LangChain — just openai + faiss directly.
"""

import os
import pickle
import re

import numpy as np
import faiss
from openai import OpenAI

from aita_core.config import get_config

# Lazy-loaded state
_client = None
_index = None
_chunks = None


[docs] def _get_client(): global _client if _client is None: _client = OpenAI() return _client
[docs] def _load_index(): global _index, _chunks if _index is None: cfg = get_config() _index = faiss.read_index(os.path.join(cfg.faiss_db_dir, "index.faiss")) with open(os.path.join(cfg.faiss_db_dir, "metadata.pkl"), "rb") as f: _chunks = pickle.load(f)
FUTURE_TOPIC_INSTRUCTION = """ IMPORTANT — WEEK-AWARE INSTRUCTION: The class is currently in Week {current_week}. {current_hw_line} Topics covered so far: {covered_topics} Topics NOT yet covered: {future_topics} STRICT RULE — FUTURE TOPICS: If the student asks about a topic that has NOT been covered yet, you MUST: 1. Say: "That's a great question! We'll cover that topic later in the course." 2. Give AT MOST a one-sentence definition — no formulas, no numbers, no calculations. 3. Do NOT provide specific values (like rates, speeds, constants) for future topics. 4. Redirect: "For now, let's focus on the current material. Is there anything about \ [current topic] I can help with?" This is an absolute rule. Even if you know the answer, do NOT provide detailed \ information about future topics. Giving incorrect or premature information is worse \ than redirecting the student. If the student asks about "this week's homework" or "the current homework", refer to {current_hw_ref}. If the student says "problem 1" or "problem 2" etc. without specifying which homework, \ assume they mean {current_hw_ref} and use the retrieved context from that homework. """ EXAM_SCOPE_INSTRUCTION = """ EXAM SCOPE INFORMATION: {exam_scope_text} CRITICAL RULE — EXAM STUDY GUIDES: When a student asks about preparing for a specific exam (e.g., "study guide for midterm 2", \ "practice exam", "what's on the midterm"), you MUST: - ONLY include topics that fall within that exam's scope as listed above. - Do NOT include topics from weeks beyond the exam's week range. - Base your study guide on the retrieved course materials, not your own knowledge. - If you are unsure which exam they mean, ask them to clarify. """ NO_CONTEXT_WARNING = """ WARNING — NO COURSE MATERIALS RETRIEVED: No course materials were found matching this query. This likely means the topic has not \ been covered yet, or the query does not match any course content. You MUST NOT provide detailed answers from your own knowledge — they may be incorrect. Instead, check if the topic appears in the "NOT yet covered" list above and redirect \ accordingly. If unsure, tell the student: "I don't have course materials on this topic \ yet. Could you rephrase your question, or is this a topic we haven't covered?" """
[docs] def build_system_prompt(current_week, has_context=True): """Build system prompt with week-awareness and exam scope.""" cfg = get_config() covered = cfg.get_topics_covered(current_week) future = cfg.get_topics_not_covered(current_week) week_to_hw = cfg.week_to_hw current_hw = week_to_hw.get(current_week, None) if not current_hw: for w in range(current_week, 0, -1): if w in week_to_hw: current_hw = week_to_hw[w] break current_hw = current_hw or "the most recent homework" current_hw_line = f"The current homework assignment is: {current_hw}" current_hw_ref = current_hw prompt = cfg.system_prompt prompt += "\n\n" + FUTURE_TOPIC_INSTRUCTION.format( current_week=current_week, covered_topics=", ".join(covered), future_topics=", ".join(future) if future else "None (all topics covered)", current_hw_line=current_hw_line, current_hw_ref=current_hw_ref, ) # Add exam scope if configured if cfg.exam_scope: lines = [] for exam_name, scope in sorted(cfg.exam_scope.items()): topics = cfg.get_exam_topics(exam_name) if topics: lines.append( f"- {exam_name} (weeks {scope['week_start']}-{scope['week_end']}): " f"{', '.join(topics)}" ) if lines: prompt += "\n\n" + EXAM_SCOPE_INSTRUCTION.format( exam_scope_text="\n".join(lines), ) # Warn when no context was retrieved if not has_context: prompt += "\n\n" + NO_CONTEXT_WARNING return prompt
[docs] def retrieve(query, k=None, current_week=15): """Retrieve top-k relevant chunks, filtered to only topics covered by current_week.""" _load_index() cfg = get_config() client = _get_client() if k is None: k = cfg.retrieval_k resp = client.embeddings.create(model=cfg.embedding_model, input=[query]) qvec = np.array([resp.data[0].embedding], dtype="float32") faiss.normalize_L2(qvec) fetch_k = min(k * 4, _index.ntotal) scores, indices = _index.search(qvec, fetch_k) results = [] for score, idx in zip(scores[0], indices[0]): if idx == -1: continue chunk_week = _chunks[idx]["metadata"].get("max_week", 1) if chunk_week > current_week: continue results.append({ "text": _chunks[idx]["text"], "source": _chunks[idx]["metadata"]["source_label"], "file_path": _chunks[idx]["metadata"].get("source", ""), "score": float(score), }) if len(results) >= k: break return results
[docs] def build_messages(chat_history, user_query, context_chunks, current_week): """Build the message list for the OpenAI chat completion.""" context = "\n\n---\n\n".join(c["text"] for c in context_chunks) system_prompt = build_system_prompt( current_week, has_context=bool(context_chunks), ) messages = [ {"role": "system", "content": system_prompt + f"\n\nRetrieved course materials:\n{context}"}, ] for msg in chat_history[-20:]: messages.append(msg) messages.append({"role": "user", "content": user_query}) return messages
[docs] def _inject_current_hw(query, context_chunks, current_week): """If the query mentions homework, ensure the current HW is in retrieved chunks.""" _load_index() cfg = get_config() hw_keywords = ["homework", "hw", "assignment", "this week's hw", "current hw"] if not any(kw in query.lower() for kw in hw_keywords): return context_chunks week_to_hw = cfg.week_to_hw current_hw = week_to_hw.get(current_week) if not current_hw: for w in range(current_week, 0, -1): if w in week_to_hw: current_hw = week_to_hw[w] break if not current_hw: return context_chunks # Check if current HW is already in results hw_label = f"Homework: {current_hw}.pdf" if any(hw_label in c.get("source", "") for c in context_chunks): return context_chunks # Find and inject the first chunk from the current HW for i, chunk in enumerate(_chunks): label = chunk["metadata"].get("source_label", "") if current_hw in label and "Homework" in label: chunk_week = chunk["metadata"].get("max_week", 1) if chunk_week <= current_week: context_chunks.insert(0, { "text": chunk["text"], "source": label, "file_path": chunk["metadata"].get("source", ""), "score": 1.0, }) break return context_chunks
def _identify_exam(query_lower, cfg): """Identify which exam the student is asking about from the query text.""" if not cfg.exam_scope: return None # Try exact name matches first (e.g., "midterm 1", "midterm 2", "final") for exam_name in sorted(cfg.exam_scope.keys()): name_lower = exam_name.lower() if name_lower in query_lower: return exam_name # Handle "midterm exam 2" matching "Midterm 2" parts = name_lower.split() if len(parts) >= 2 and all(p in query_lower for p in parts): return exam_name # Match "final" keyword if "final" in query_lower: for name in cfg.exam_scope: if "final" in name.lower(): return name # Match generic "midterm" — pick the most relevant one if "midterm" in query_lower: # Check for a number in the query match = re.search(r"midterm\s*(?:exam\s*)?(\d+)", query_lower) if match: target = f"Midterm {match.group(1)}" if target in cfg.exam_scope: return target # No number — return the latest midterm (students usually ask about upcoming) midterms = sorted( [(n, s) for n, s in cfg.exam_scope.items() if "midterm" in n.lower()], key=lambda x: x[1].get("week_end", 0), ) if midterms: return midterms[-1][0] return None def _inject_exam_review(query, context_chunks, current_week): """For exam-related queries, retrieve topic-relevant content using exam scope.""" cfg = get_config() if not cfg.exam_scope: return context_chunks exam_keywords = [ "midterm", "study guide", "review for", "practice exam", "final exam", "prepare for exam", "what's on the exam", "what will be on", "exam review", ] query_lower = query.lower() if not any(kw in query_lower for kw in exam_keywords): return context_chunks target_exam = _identify_exam(query_lower, cfg) if not target_exam: return context_chunks topics = cfg.get_exam_topics(target_exam) if not topics: return context_chunks # Retrieve chunks using exam topics as a synthetic query topic_query = f"Key concepts for exam review: {', '.join(topics)}" topic_results = retrieve(topic_query, k=cfg.retrieval_k, current_week=current_week) # Merge: add new unique sources from topic retrieval existing_sources = {c["source"] for c in context_chunks} for chunk in topic_results: if chunk["source"] not in existing_sources: context_chunks.append(chunk) existing_sources.add(chunk["source"]) return context_chunks
[docs] def chat(user_query, chat_history=None, current_week=15): """ Full RAG pipeline: retrieve context, build prompt, generate response. Returns (assistant_message, sources). """ cfg = get_config() client = _get_client() if chat_history is None: chat_history = [] context_chunks = retrieve(user_query, current_week=current_week) context_chunks = _inject_current_hw(user_query, context_chunks, current_week) context_chunks = _inject_exam_review(user_query, context_chunks, current_week) seen = set() sources = [] for c in context_chunks: if c["source"] not in seen: seen.add(c["source"]) sources.append({"label": c["source"], "file_path": c["file_path"]}) messages = build_messages(chat_history, user_query, context_chunks, current_week) response = client.chat.completions.create( model=cfg.llm_model, messages=messages, temperature=cfg.llm_temperature, ) assistant_message = response.choices[0].message.content return assistant_message, sources