Source code for aita_core.ingest

"""
Document ingestion utilities for AITA.

Standard usage (default directory layout):

    from config import CONFIG
    from aita_core.ingest import run_ingestion
    run_ingestion(CONFIG)

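Custom collectors (non-standard layout). Each collector is a callable that
takes the config and returns a list of document dicts, each with a ``"text"``
string and a ``"metadata"`` dict containing at least ``"source_label"`` (the
built-in loaders also set ``"source"`` and ``"max_week"``):

    from config import CONFIG
    from aita_core.ingest import run_ingestion

    def my_collect_handouts(config):
        ...
        return docs

    def my_collect_homework(config):
        ...
        return docs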

    run_ingestion(CONFIG, collectors=[
        ("handouts", my_collect_handouts),
        ("homework", my_collect_homework),
    ])
"""

import os
import re
import pickle
import shutil
import datetime

import numpy as np
import faiss
from openai import OpenAI

from aita_core.utils import save_docs_to_jsonl


# Module-level cache for the OpenAI client; populated on first use.
_client = None


def _get_client():
    """Return the shared OpenAI client, creating it on the first call.

    Later calls return the same instance instead of constructing a new one.
    """
    global _client
    if _client is not None:
        return _client
    _client = OpenAI()
    return _client


# ---------------------------------------------------------------------------
# Week tagging
# ---------------------------------------------------------------------------

def get_week_for_filename(filename, topic_num_to_week, hw_num_to_week,
                          lab_num_to_week, study_guide_to_week,
                          unmapped_week=15, default_week=1):
    """Determine which week a document belongs to based on its filename.

    Numbered materials are recognised by filename prefix, checked in order:

    * ``"<n> ..."`` (digits followed by whitespace) -> ``topic_num_to_week``
    * ``"HW<n>..."``                               -> ``hw_num_to_week``
    * ``"Lab<n>..."`` / ``"Lab <n>..."``           -> ``lab_num_to_week``

    If none match, the longest key of ``study_guide_to_week`` that prefixes
    ``filename`` decides the week.

    Args:
        filename: Bare file or directory name (not a full path).
        topic_num_to_week: Map of topic number -> week.
        hw_num_to_week: Map of homework number -> week.
        lab_num_to_week: Map of lab number -> week.
        study_guide_to_week: Map of filename prefix -> week.
        unmapped_week: Week returned for a numbered file whose number is
            missing from its map (default 15, the previous hard-coded value).
        default_week: Week returned for anything unrecognised, including the
            syllabus (default 1, the previous hard-coded value).

    Returns:
        The week number as stored in the maps (or one of the fallbacks).
    """
    numbered_patterns = (
        (r"^(\d+)\s", topic_num_to_week),
        (r"^HW(\d+)", hw_num_to_week),
        (r"^Lab\s*(\d+)", lab_num_to_week),
    )
    for pattern, num_to_week in numbered_patterns:
        match = re.match(pattern, filename)
        if match:
            return num_to_week.get(int(match.group(1)), unmapped_week)

    # Longest prefix first so e.g. "Midterm 2" wins over "Midterm".
    for guide_key, week in sorted(study_guide_to_week.items(), key=lambda kv: -len(kv[0])):
        if filename.startswith(guide_key):
            return week

    # The syllabus and any unrecognised file fall through to the default week.
    return default_week
# ---------------------------------------------------------------------------
# Text extraction
# ---------------------------------------------------------------------------


def _extract_pdf_text(file_path):
    """Extract text from a PDF using pymupdf (preferred) or pdfminer (fallback).

    With pymupdf, each page's text is kept as extracted and pages are joined
    with newlines. With pdfminer, form feeds and newlines are replaced by
    spaces, runs of whitespace are collapsed, and the result is stripped.

    Args:
        file_path: Path to the PDF file.

    Returns:
        The extracted text (may be empty).

    Raises:
        ImportError: if neither pymupdf nor pdfminer is installed.
    """
    # Only the import decides the backend; errors raised while reading the
    # PDF must propagate rather than silently triggering the fallback.
    try:
        import pymupdf
    except ImportError:
        pymupdf = None

    if pymupdf is not None:
        # ``with`` closes the document even if a page fails to extract.
        with pymupdf.open(file_path) as doc:
            return "\n".join(page.get_text("text") for page in doc)

    from pdfminer.high_level import extract_text
    text = extract_text(file_path)
    # pdfminer cleanup: collapse whitespace for readability
    text = text.replace("\x0c", " ")
    text = text.replace("\n", " ")
    text = re.sub(r"\s{2,}", " ", text)
    return text.strip()
def load_pdf(file_path, source_label, max_week=1):
    """Load a PDF as a single document record.

    Args:
        file_path: Path to the PDF; also stored as the ``source`` metadata.
        source_label: Human-readable label stored as ``source_label``.
        max_week: Week value stored as ``max_week`` metadata.

    Returns:
        A one-element list ``[{"text": ..., "metadata": {...}}]`` holding the
        stripped text, or an empty list when no text could be extracted.
    """
    cleaned = (_extract_pdf_text(file_path) or "").strip()
    if not cleaned:
        return []
    metadata = {"source": file_path, "source_label": source_label, "max_week": max_week}
    return [{"text": cleaned, "metadata": metadata}]
def load_tex(file_path, source_label, max_week=1):
    """Load a LaTeX source file as a single document record, minus boilerplate.

    Whole-line ``%`` comments are removed, as are these commands:

    * ``documentclass`` and ``usepackage``, together with their optional
      ``[...]`` and ``{...}`` arguments, so fragments such as
      ``[11pt]{article}`` do not leak into the indexed text;
    * ``begin{document}``, ``end{document}``, ``maketitle``, ``input{...}``
      and simple ``newcommand`` definitions.

    Args:
        file_path: Path to the ``.tex`` file (read as UTF-8, undecodable bytes
            ignored); also stored as the ``source`` metadata.
        source_label: Human-readable label stored as ``source_label``.
        max_week: Week value stored as ``max_week`` metadata.

    Returns:
        A one-element list ``[{"text": ..., "metadata": {...}}]``, or an empty
        list if nothing remains after cleanup.
    """
    with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
        content = f.read()
    # Only whole-line comments are removed; inline "%" comments are left as-is.
    content = re.sub(r"(?m)^%.*$", "", content)
    content = re.sub(
        r"\\(?:documentclass|usepackage)(?:\[[^\]]*\])?(?:\{[^}]*\})?"
        r"|\\(?:begin\{document\}|end\{document\}|maketitle|input\{[^}]*\}|newcommand[^}]*\}[^}]*\})",
        "",
        content,
    )
    content = content.strip()
    if not content:
        return []
    return [{"text": content, "metadata": {"source": file_path, "source_label": source_label, "max_week": max_week}}]
def load_wikibook_page(url, timeout=30):
    """Fetch a Wikibook page and extract clean text content.

    Only text inside the element whose ``id`` is ``"mw-content-text"`` is
    kept. Capture stops when that element closes, so text after it (footer,
    categories, ...) is excluded. Text inside ``script``, ``style``, ``sup``
    and ``nav`` elements is dropped; nesting is tracked, so a skip element
    inside another skip element does not re-enable capture early. A newline
    is emitted after each ``p``, ``h2``, ``h3``, ``h4`` and ``li`` element.

    Args:
        url: Absolute URL of the page to fetch.
        timeout: Socket timeout in seconds passed to ``urlopen`` so a stalled
            server cannot hang ingestion indefinitely.

    Returns:
        The extracted text, with runs of 3+ newlines collapsed to a single
        blank line and runs of spaces/tabs collapsed to one space, stripped.

    Raises:
        Propagates exceptions from ``urllib.request.urlopen`` (e.g.
        ``urllib.error.URLError``) when the fetch fails.
    """
    import urllib.request
    from html.parser import HTMLParser

    skip_tags = {"script", "style", "sup", "nav"}
    break_tags = {"p", "h2", "h3", "h4", "li"}

    class _TextExtractor(HTMLParser):
        def __init__(self):
            super().__init__()
            self.root_tag = None   # tag name of the content root, once found
            self.root_depth = 0    # open root_tag elements since the root opened
            self.done = False      # True once the content root has closed
            self.skip_depth = 0    # number of currently open skip_tags elements
            self.text = []

        @property
        def in_content(self):
            return self.root_tag is not None and not self.done

        def handle_starttag(self, tag, attrs):
            if self.root_tag is None:
                if dict(attrs).get("id") == "mw-content-text":
                    self.root_tag = tag
                    self.root_depth = 1
            elif self.in_content and tag == self.root_tag:
                # Count nested elements of the same type so we can tell which
                # end tag closes the content root itself.
                self.root_depth += 1
            if tag in skip_tags:
                self.skip_depth += 1

        def handle_endtag(self, tag):
            if tag in skip_tags and self.skip_depth:
                self.skip_depth -= 1
            if not self.in_content:
                return
            if tag in break_tags:
                self.text.append("\n")
            if tag == self.root_tag:
                self.root_depth -= 1
                if self.root_depth == 0:
                    self.done = True

        def handle_data(self, data):
            if self.in_content and not self.skip_depth:
                self.text.append(data)

    req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        html = resp.read().decode("utf-8", errors="ignore")

    parser = _TextExtractor()
    parser.feed(html)
    parser.close()  # flush any text HTMLParser is still buffering
    text = "".join(parser.text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[ \t]+", " ", text)
    return text.strip()
def collect_wikibook(config):
    """Collect chapters from a Wikibook URL.

    Requires config to have:
        textbook_url: str — base URL, e.g.
            "https://en.wikibooks.org/wiki/Fundamentals_of_Transportation"
        textbook_chapter_to_week: dict[str, int] — maps chapter slug to week,
            e.g. {"Trip_Generation": 2, "Route_Choice": 4}

    Chapters are fetched in ascending week order from ``<textbook_url>/<slug>``.
    A chapter that fails to download is reported and skipped, so one bad page
    does not abort the whole ingestion run. Chapters with no extracted text
    are omitted.

    Returns:
        A list of ``{"text": ..., "metadata": {...}}`` documents, or an empty
        list when either config attribute is missing or empty.
    """
    url = getattr(config, "textbook_url", None)
    chapter_to_week = getattr(config, "textbook_chapter_to_week", None)
    if not url or not chapter_to_week:
        print(" Skipping: textbook_url or textbook_chapter_to_week not set")
        return []

    base_url = url.rstrip("/")
    docs = []
    for chapter_slug, week in sorted(chapter_to_week.items(), key=lambda x: x[1]):
        chapter_url = f"{base_url}/{chapter_slug}"
        chapter_name = chapter_slug.replace("_", " ")
        label = f"Textbook: {chapter_name}"
        print(f" Loading {label} (week {week})")
        try:
            text = load_wikibook_page(chapter_url)
        except Exception as e:  # best-effort: report and move on to the next chapter
            print(f" Error fetching {chapter_url}: {e}")
            continue
        if text:
            docs.append({
                "text": text,
                "metadata": {
                    "source": chapter_url,
                    "source_label": label,
                    "max_week": week,
                },
            })
    return docs
# ---------------------------------------------------------------------------
# Chunking
# ---------------------------------------------------------------------------
def chunk_text(text, chunk_size=2048, overlap=256):
    """Split text into overlapping fixed-size character chunks.

    Consecutive chunks start ``chunk_size - overlap`` characters apart, so each
    chunk repeats the last ``overlap`` characters of its predecessor. Splitting
    stops as soon as a chunk reaches the end of ``text``. This avoids emitting a
    trailing chunk that lies entirely inside the previous one.

    Args:
        text: The string to split.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared between consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        A list of chunks (empty for empty ``text``).

    Raises:
        ValueError: if ``chunk_size``/``overlap`` are out of range. Such
            values previously looped forever or silently skipped text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got overlap={overlap}, chunk_size={chunk_size}"
        )

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            break
    return chunks
def chunk_documents(docs, chunk_size=2048, overlap=256):
    """Split every document into chunks, each prefixed with its source label.

    Args:
        docs: Documents of the form ``{"text": str, "metadata": dict}``; the
            metadata must contain ``"source_label"``.
        chunk_size: Passed through to ``chunk_text``.
        overlap: Passed through to ``chunk_text``.

    Returns:
        A flat list of ``{"text": "Source: <label>\\n<chunk>", "metadata": ...}``
        records, in document order. All chunks of one document share that
        document's metadata dict object.
    """
    return [
        {
            "text": f"Source: {doc['metadata']['source_label']}\n{piece}",
            "metadata": doc["metadata"],
        }
        for doc in docs
        for piece in chunk_text(doc["text"], chunk_size, overlap)
    ]
# ---------------------------------------------------------------------------
# Embeddings
# ---------------------------------------------------------------------------
def get_embeddings(texts, embedding_model="text-embedding-3-large", batch_size=100):
    """Embed ``texts`` with the OpenAI embeddings API, ``batch_size`` at a time.

    Progress is printed once per batch.

    Args:
        texts: Sequence of strings to embed.
        embedding_model: OpenAI embedding model name.
        batch_size: Number of texts sent per API request.

    Returns:
        A float32 numpy array with one row per input text, in input order.
    """
    client = _get_client()
    offsets = range(0, len(texts), batch_size)
    batch_count = len(offsets)
    vectors = []
    for batch_no, offset in enumerate(offsets, start=1):
        batch = texts[offset:offset + batch_size]
        print(f" Embedding batch {batch_no}/{batch_count} ({len(batch)} chunks)")
        response = client.embeddings.create(model=embedding_model, input=batch)
        vectors.extend(item.embedding for item in response.data)
    return np.array(vectors, dtype="float32")
# ---------------------------------------------------------------------------
# FAISS index management
# ---------------------------------------------------------------------------
def build_faiss_index(embeddings):
    """Build an inner-product FAISS index over L2-normalised embeddings.

    Rows are unit-normalised before insertion, so inner-product scores equal
    cosine similarity for queries normalised the same way (NOTE(review):
    confirm the query side also calls ``faiss.normalize_L2``).

    The caller's array is not modified: normalisation is applied to a
    float32, C-contiguous copy.

    Args:
        embeddings: 2-D array-like of shape ``(n_vectors, dim)``.

    Returns:
        A ``faiss.IndexFlatIP`` containing the normalised vectors.

    Raises:
        ValueError: if ``embeddings`` is not 2-D (e.g. an empty 1-D array).
    """
    # np.array copies by default; order="C"/float32 is the layout FAISS needs.
    vectors = np.array(embeddings, dtype="float32", order="C")
    if vectors.ndim != 2:
        raise ValueError(f"embeddings must be a 2-D array, got shape {vectors.shape}")
    faiss.normalize_L2(vectors)
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index
def save_index(index, chunks, faiss_dir, docs_dir, backup_dir):
    """Save the FAISS index and chunk metadata, backing up any existing data.

    Writes ``index.faiss`` and ``metadata.pkl`` (pickled ``chunks``) into
    ``faiss_dir``, and ``doc.jsonl`` into ``docs_dir``. If an index already
    exists, it is first copied, along with any existing ``metadata.pkl`` and
    ``doc.jsonl``, into ``backup_dir/<YYYYmmdd-HHMMSS>/``.

    Args:
        index: FAISS index to write.
        chunks: Chunk records; pickled as-is and passed to
            ``save_docs_to_jsonl``.
        faiss_dir: Directory for the index and pickled metadata.
        docs_dir: Directory for ``doc.jsonl``.
        backup_dir: Parent directory for timestamped backups.
    """
    for directory in (faiss_dir, docs_dir, backup_dir):
        os.makedirs(directory, exist_ok=True)

    index_path = os.path.join(faiss_dir, "index.faiss")
    meta_path = os.path.join(faiss_dir, "metadata.pkl")
    doc_jsonl = os.path.join(docs_dir, "doc.jsonl")

    # Seconds resolution: with minutes only, two runs in the same minute
    # shared one backup directory and the second overwrote the first backup.
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    if os.path.exists(index_path):
        print("Existing index found. Backing up...")
        bak = os.path.join(backup_dir, timestamp)
        os.makedirs(bak, exist_ok=True)
        shutil.copy2(index_path, os.path.join(bak, "index.faiss"))
        if os.path.exists(meta_path):
            shutil.copy2(meta_path, os.path.join(bak, "metadata.pkl"))
        if os.path.exists(doc_jsonl):
            shutil.copy2(doc_jsonl, os.path.join(bak, "doc.jsonl"))

    faiss.write_index(index, index_path)
    # NOTE: pickle is only safe to load from trusted locations; this file is
    # produced locally by the ingestion pipeline.
    with open(meta_path, "wb") as f:
        pickle.dump(chunks, f)
    save_docs_to_jsonl(chunks, doc_jsonl)

    print(f"Index saved to {faiss_dir} ({index.ntotal} vectors)")
    print(f"Document records saved to {doc_jsonl}")
# ---------------------------------------------------------------------------
# Default collectors (standard directory layout)
# ---------------------------------------------------------------------------


def _week_for(filename, config):
    """Resolve the week for ``filename`` using the week maps held on ``config``.

    Thin adapter over ``get_week_for_filename`` that reads
    ``topic_num_to_week``, ``hw_num_to_week``, ``lab_num_to_week`` and
    ``study_guide_to_week`` from ``config``.
    """
    week_maps = (
        config.topic_num_to_week,
        config.hw_num_to_week,
        config.lab_num_to_week,
        config.study_guide_to_week,
    )
    return get_week_for_filename(filename, *week_maps)
def collect_syllabus(config_or_dir):
    """Load the syllabus from ``<course_materials_dir>/syllabus/``.

    ``Syllabus.tex`` is preferred; ``Syllabus.pdf`` is used only when no
    ``.tex`` exists. The syllabus is always tagged as week 1.

    Args:
        config_or_dir: Either a CourseConfig (its ``course_materials_dir`` is
            used) or, for backwards compatibility, the course materials
            directory itself as a ``str`` or any ``os.PathLike`` (e.g.
            ``pathlib.Path``).

    Returns:
        A list with the syllabus document, or an empty list if neither file
        exists.
    """
    if isinstance(config_or_dir, (str, os.PathLike)):
        course_materials_dir = os.fspath(config_or_dir)
    else:
        course_materials_dir = config_or_dir.course_materials_dir

    docs = []
    syllabus_dir = os.path.join(course_materials_dir, "syllabus")
    tex_path = os.path.join(syllabus_dir, "Syllabus.tex")
    pdf_path = os.path.join(syllabus_dir, "Syllabus.pdf")
    if os.path.exists(tex_path):
        print(" Loading Syllabus (LaTeX, week 1)")
        docs.extend(load_tex(tex_path, "Syllabus", max_week=1))
    elif os.path.exists(pdf_path):
        print(" Loading Syllabus (PDF, week 1)")
        docs.extend(load_pdf(pdf_path, "Syllabus", max_week=1))
    return docs
def collect_handouts(config):
    """Load handout PDFs from ``<course_materials_dir>/Handouts/Handouts/``.

    Files are processed in sorted filename order. Only files whose name ends
    in ``.pdf`` (case-insensitive, so ``Notes.PDF`` is included) are loaded.
    Each one is tagged with the week from its filename.

    Returns:
        A list of documents; empty (after printing a warning) if the
        directory does not exist.
    """
    docs = []
    handouts_dir = os.path.join(config.course_materials_dir, "Handouts", "Handouts")
    if not os.path.isdir(handouts_dir):
        print(f" Warning: {handouts_dir} not found")
        return docs
    for filename in sorted(os.listdir(handouts_dir)):
        if not filename.lower().endswith(".pdf"):
            continue
        file_path = os.path.join(handouts_dir, filename)
        label = f"Handout: {filename}"
        week = _week_for(filename, config)
        print(f" Loading {label} (week {week})")
        docs.extend(load_pdf(file_path, label, max_week=week))
    return docs
def collect_homework(config):
    """Load homework PDFs from ``Homework handouts/Homework handouts/``, skipping solutions.

    Files are processed in sorted filename order. Only files whose name ends
    in ``.pdf`` (case-insensitive) are considered. Any file whose name
    contains ``"solution"`` (case-insensitive) is skipped so solutions are
    never indexed. Each remaining file is tagged with the week from its
    filename.

    Returns:
        A list of documents; empty (after printing a warning) if the
        directory does not exist.
    """
    docs = []
    hw_dir = os.path.join(config.course_materials_dir, "Homework handouts", "Homework handouts")
    if not os.path.isdir(hw_dir):
        print(f" Warning: {hw_dir} not found")
        return docs
    for filename in sorted(os.listdir(hw_dir)):
        if not filename.lower().endswith(".pdf"):
            continue
        if "solution" in filename.lower():
            print(f" Skipping (solution): {filename}")
            continue
        file_path = os.path.join(hw_dir, filename)
        label = f"Homework: {filename}"
        week = _week_for(filename, config)
        print(f" Loading {label} (week {week})")
        docs.extend(load_pdf(file_path, label, max_week=week))
    return docs
def collect_slides(config):
    """Load slide content from ``Slides/Slides/<topic>/`` directories.

    Each topic subdirectory contributes ``content.tex`` if present, otherwise
    ``Notes.pdf`` if present, otherwise nothing. Non-directory entries are
    ignored, and the week comes from the topic directory name.

    Returns:
        A list of documents; empty (after printing a warning) if the slides
        directory does not exist.
    """
    slides_dir = os.path.join(config.course_materials_dir, "Slides", "Slides")
    if not os.path.isdir(slides_dir):
        print(f" Warning: {slides_dir} not found")
        return []

    docs = []
    for topic_name in sorted(os.listdir(slides_dir)):
        topic_path = os.path.join(slides_dir, topic_name)
        if not os.path.isdir(topic_path):
            continue
        label = f"Slides: {topic_name}"
        week = _week_for(topic_name, config)
        content_tex = os.path.join(topic_path, "content.tex")
        notes_pdf = os.path.join(topic_path, "Notes.pdf")
        if os.path.exists(content_tex):
            print(f" Loading {label} (LaTeX, week {week})")
            docs.extend(load_tex(content_tex, label, max_week=week))
        elif os.path.exists(notes_pdf):
            print(f" Loading {label} (PDF, week {week})")
            docs.extend(load_pdf(notes_pdf, label, max_week=week))
    return docs
# ---------------------------------------------------------------------------
# Ingestion pipeline runner
# ---------------------------------------------------------------------------
def run_ingestion(config, collectors=None):
    """Run the full document ingestion pipeline.

    The pipeline collects documents, chunks them, embeds the chunks, builds a
    FAISS index and saves it (with a backup of any previous index).

    Args:
        config: CourseConfig instance. Reads ``course_id``, ``chunk_size``,
            ``chunk_overlap``, ``embedding_model``, ``faiss_db_dir``,
            ``docs_dir`` and ``backup_dir``, plus whatever the collectors use.
        collectors: Optional iterable of ``(name, callable)`` pairs. Each
            callable receives ``config`` and returns a list of docs; returning
            ``None`` is treated as "no documents". If None, the default
            collectors for the standard directory layout are used, plus the
            textbook collector when ``textbook_url`` and
            ``textbook_chapter_to_week`` are set.

    Returns:
        None. Returns early without writing anything if no documents, or no
        non-empty chunks, were produced.
    """
    if collectors is None:
        collectors = [
            ("lecture handouts", collect_handouts),
            ("homework questions", collect_homework),
            ("slide content", collect_slides),
            ("syllabus", collect_syllabus),
        ]
        if getattr(config, "textbook_url", "") and getattr(config, "textbook_chapter_to_week", {}):
            collectors.append(("textbook", collect_wikibook))
    else:
        # Materialise so tuples/generators work and len() is available.
        collectors = list(collectors)

    total = len(collectors)
    print("=" * 60)
    print(f"AITA {config.course_id} Document Ingestion Pipeline")
    print("=" * 60)

    all_docs = []
    for i, (name, collector_fn) in enumerate(collectors, 1):
        print(f"\n[{i}/{total}] Collecting {name}...")
        all_docs.extend(collector_fn(config) or [])

    if not all_docs:
        print("\nNo documents found. Check course_materials directory.")
        return

    print(f"\nTotal documents loaded: {len(all_docs)}")
    chunks = chunk_documents(all_docs, config.chunk_size, config.chunk_overlap)
    print(f"Total chunks after splitting: {len(chunks)}")
    if not chunks:
        # Possible when custom collectors return docs with empty "text"; the
        # embedding and index-building steps need at least one chunk.
        print("\nNo text to index after chunking. Nothing saved.")
        return

    print(f"\nGenerating embeddings with {config.embedding_model}...")
    texts = [c["text"] for c in chunks]
    embeddings = get_embeddings(texts, config.embedding_model)

    print("\nBuilding FAISS index...")
    index = build_faiss_index(embeddings)
    save_index(index, chunks, config.faiss_db_dir, config.docs_dir, config.backup_dir)
    print("\nDone! Vector store is ready.")