Source code for starling.search.builder

"""
FAISS Index Builder
===================

Build high-performance FAISS indexes with OPQ+IVF-PQ compression and integrated
SQLite sequence metadata storage.

Overview
--------
The IndexBuilder creates FAISS indexes optimized for billion-scale similarity search:

* **OPQ (Optimized Product Quantization)**: Learned rotation for better quantization
* **IVF (Inverted File)**: Clustering for fast approximate search
* **PQ (Product Quantization)**: Compression to ~64-128 bytes per vector
* **Sequence Store**: SQLite database with sequences, headers, and metadata

Basic Usage
-----------
>>> from starling.search import IndexBuilder
>>>
>>> # Initialize and discover shards
>>> builder = IndexBuilder(
...     root="/path/to/shards",
...     metric="cosine",
...     verbose=True
... )
>>>
>>> # Build index with sequence store
>>> builder.build_index(
...     index_path="my_index.faiss",
...     tokens_dir="/path/to/tokens",
...     sample_size=655360,     # Training samples
...     nlist=16384,            # IVF clusters
...     m=64,                   # PQ subquantizers
...     nbits=8,                # Bits per subquantizer
...     use_gpu=True,           # GPU acceleration
...     use_opq=True,           # Enable OPQ
...     compress_sequences=True # Compress sequences with zstd
... )

Index Configuration
-------------------
**Training Parameters:**

* ``sample_size``: Number of vectors for training (default: 655,360)

  - Larger = better quantization quality
  - Minimum: 39 * nlist vectors required
  - Recommended: 100-1000 vectors per IVF cluster

* ``nlist``: Number of IVF clusters (default: 16,384)

  - More clusters = faster search but lower recall
  - Rule of thumb: sqrt(N) to N/1000 where N is total vectors
  - Auto-adjusted if training samples insufficient

* ``m``: PQ subquantizers (default: 64)

  - Must divide vector dimension evenly
  - More = better quality but larger memory footprint
  - Typical: 32-128 depending on dimension

* ``nbits``: Bits per subquantizer (default: 8)

  - Controls codebook size: 2^nbits entries per subquantizer
  - 8 bits = 256 entries (standard)
  - 16 bits = 65536 entries (higher quality, more memory)

**GPU Options:**

* ``use_gpu``: Train on GPU
* ``gpu_device``: CUDA device ID
* ``gpu_fp16_lut``: Use float16 lookup tables (to address 48KB SMEM limit)

**Sequence Store:**

* ``tokens_dir``: Directory with tokenized sequences (.tokens.pt files)
* ``compress_sequences``: Use zstd compression

File Structure
--------------
The builder expects sharded feature files in this structure::

    root/
    ├── uniref50_idrs_only_000000/
    │   └── sequence_features.pt
    ├── uniref50_idrs_only_000001/
    │   └── sequence_features.pt
    └── ...

And produces these outputs::

    my_index.faiss              # FAISS index
    my_index.faiss.manifest.json # Metadata
    my_index.faiss.seqs.sqlite  # Sequence store (if tokens_dir provided)


See Also
--------
* :class:`SearchEngine`: Query the built index
* :class:`SequenceStore`: Direct database access
* :mod:`starling.search.cli`: Command-line interface

Notes
-----
* Use ``verbose=True`` for progress tracking during long builds
* GPU training dramatically faster but index quality identical
"""

from __future__ import annotations

import glob
import json
import logging
import os
import re
import time
from typing import Any, List, Optional

import numpy as np
import torch
from tqdm import tqdm

from starling.search.store import SequenceStore

try:
    import faiss  # type: ignore
except Exception as e:
    raise ImportError(
        "FAISS is required. Install with 'pip install faiss-gpu' (CUDA) or 'pip install faiss-cpu'."
    ) from e

logger = logging.getLogger(__name__)

# Default shard ID regex (captures 6-digit zero-padded numeric id)
_SHARD_ID_RE_DEFAULT = re.compile(r"uniref50_idrs_only_(\d{6})")


def _parse_shard_id_from_path(path: str) -> int:
    """
    Extract numeric shard id from a features path like:
    .../uniref50_idrs_only_000123/sequence_features.pt

    Returns an int (e.g., 123). Raises if not found.
    """
    m = _SHARD_ID_RE_DEFAULT.search(path)
    if not m:
        raise ValueError(f"Could not parse shard id from path: {path}")
    return int(m.group(1))


[docs] class IndexBuilder: """ Builds FAISS OPQ+IVF-PQ indexes from sharded feature files. """
[docs] def __init__( self, root: str, metric: str = "cosine", verbose: bool = True, shard_id_regex: Optional[str] = None, ): if metric not in {"l2", "cosine"}: raise ValueError("metric must be 'l2' or 'cosine'") self.root = root self.metric = metric self.verbose = verbose if verbose: self._log = lambda msg, *a, **k: logger.info(str(msg)) else: self._log = lambda *a, **k: None # Compile shard id regex (user override or default) try: self._shard_id_re = ( re.compile(shard_id_regex) if shard_id_regex else _SHARD_ID_RE_DEFAULT ) except re.error as e: raise ValueError(f"Invalid shard_id_regex pattern: {shard_id_regex}: {e}") self._shard_id_regex_str = shard_id_regex or _SHARD_ID_RE_DEFAULT.pattern self.files: List[str] = [] self.shard_ids: List[int] = [] self.counts: List[int] = [] self.total: int = 0 self.dim: Optional[int] = None self._discover_files()
def _parse_shard_id(self, path: str) -> int: """Extract shard id using configured regex. Raises ValueError if not found.""" m = self._shard_id_re.search(path) if not m: raise ValueError( f"Could not parse shard id with pattern '{self._shard_id_regex_str}' from path: {path}" ) # Expect first capture group is the numeric id try: return int(m.group(1)) except Exception as e: raise ValueError( f"Shard id capture group not numeric in path: {path}" ) from e def _extract_features_from_data(self, data: Any, *, path: str) -> torch.Tensor: """Accepts various formats; returns (N, D) float32.""" if isinstance(data, torch.Tensor): t = data.float() if t.dim() == 1: t = t.unsqueeze(0) if t.dim() != 2: raise ValueError( f"Tensor in {path} must be 1D or 2D, got {tuple(t.shape)}" ) return t if isinstance(data, dict): # Case: data['features'] is a dict of header -> 1D/2D tensor if "features" in data and isinstance(data["features"], dict): feats_dict = data["features"] rows = [] for h, v in feats_dict.items(): if not isinstance(v, torch.Tensor): continue vt = v.float() if vt.dim() == 1: vt = vt.unsqueeze(0) if vt.dim() != 2: raise ValueError( f"Feature tensor for header '{h}' in {path} must be 1D or 2D" ) rows.append(vt) if not rows: raise ValueError( f"No tensor features found in features dict for {path}" ) return torch.cat(rows, dim=0) # Direct tensor under 'features' if "features" in data and isinstance(data["features"], torch.Tensor): return self._extract_features_from_data(data["features"], path=path) # Dict of header -> 1D tensor (legacy case) kv = [(k, v) for k, v in data.items() if isinstance(v, torch.Tensor)] if kv: vecs = [] for _, v in kv: v = v.float() if v.dim() == 1: v = v.unsqueeze(0) if v.dim() != 2 or v.size(0) != 1: raise ValueError(f"All feature entries must be 1D in {path}") vecs.append(v) return torch.cat(vecs, dim=0) if isinstance(data, list): # List of tensors -> stack if all(isinstance(x, torch.Tensor) for x in data) and data: proc = [] for v in data: v = v.float() if v.dim() == 1: v = v.unsqueeze(0) if v.dim() != 2 or v.size(0) != 1: raise ValueError( f"All tensors in list must be 1D; offending shape {tuple(v.shape)} in {path}" ) proc.append(v) return torch.cat(proc, dim=0) # List of dicts containing embedding tensor under common keys emb_keys = ("embedding", "feature", "tensor") rows = [] for el in data: if isinstance(el, dict): tensor = None for k in emb_keys: if k in el and isinstance(el[k], torch.Tensor): tensor = el[k] break if tensor is not None: t = tensor.float() if t.dim() == 1: t = t.unsqueeze(0) if t.dim() != 2 or t.size(0) != 1: raise ValueError( f"Embedded tensor must be 1D in list element of {path}" ) rows.append(t) else: # If dict has exactly one tensor value, use it tv = [v for v in el.values() if isinstance(v, torch.Tensor)] if len(tv) == 1: t = tv[0].float() if t.dim() == 1: t = t.unsqueeze(0) if t.dim() != 2 or t.size(0) != 1: raise ValueError( f"Tensor must be 1D in list element of {path}" ) rows.append(t) if rows: return torch.cat(rows, dim=0) raise ValueError(f"Unsupported .pt format in {path}") def _load_features(self, path: str) -> torch.Tensor: data = torch.load(path, map_location="cpu") return self._extract_features_from_data(data, path=path) def _discover_files(self) -> None: """Discover and inventory all feature shard files.""" t0 = time.time() pattern = os.path.join(self.root, "**", "sequence_features.pt") found_paths = glob.glob(pattern, recursive=True) records = [] for p in found_paths: try: sid = self._parse_shard_id(p) except ValueError: if self.verbose: self._log(f"[DISCOVER][SKIP] Could not parse shard id: {p}") continue arr = self._load_features(p) n, d = arr.shape records.append((sid, p, int(n), int(d))) # Sort strictly by numeric shard id to keep pairing stable records.sort(key=lambda t: t[0]) if self.verbose: self._log(f"[DISCOVER] {len(records)} shard files") # Populate class fields in the sorted order self.files, self.shard_ids, self.counts = [], [], [] self.dim = None for sid, p, n, d in records: if self.dim is None: self.dim = d elif self.dim != d: raise ValueError(f"Dim mismatch {self.dim} vs {d} in {p}") self.files.append(p) self.shard_ids.append(sid) self.counts.append(n) self.total = sum(self.counts) self._log( f"[DISCOVER] total_vectors={self.total} dim={self.dim} time={time.time() - t0:.2f}s" )
[docs] def sample_vectors( self, sample_size: int = 100_000, seed: int = 1234 ) -> torch.Tensor: """Randomly sample vectors across all shards for training. Parameters ---------- sample_size : int, optional Number of vectors to sample, by default 100_000 seed : int, optional Random seed for reproducibility, by default 1234 Returns ------- torch.Tensor Sampled vectors of shape (sample_size, dim) Raises ------ RuntimeError If no samples are collected """ self._log(f"[SAMPLE] target={sample_size}") import random random.seed(seed) picks = [] remaining = sample_size for p, n in zip(self.files, self.counts): if remaining <= 0: break take = min(n, max(1, int(sample_size * (n / max(1, self.total))))) if take <= 0: continue arr = self._load_features(p) if take >= n: picks.append(arr) else: idxs = random.sample(range(n), take) picks.append(arr[idxs]) remaining -= take if not picks: raise RuntimeError("No samples collected") out = torch.cat(picks, dim=0) # Shuffle if we over-sampled if out.size(0) > sample_size: perm = torch.randperm(out.size(0))[:sample_size] out = out[perm] self._log(f"[SAMPLE] actual={out.size(0)}") return out.float()
def _create_and_train_index( self, train_np: np.ndarray, nlist: int, m: int, nbits: int, use_gpu: bool, gpu_device: int, gpu_fp16_lut: bool, use_opq: bool, ) -> faiss.Index: """Create and train the FAISS index. Parameters ---------- train_np : np.ndarray Training data as a NumPy array of shape (n_samples, dim). nlist : int Number of inverted file list (IVF) partitions. m : int Number of subquantizers. nbits : int Number of bits per subvector. use_gpu : bool Whether to use GPU for training. gpu_device : int GPU device ID to use (if use_gpu is True). gpu_fp16_lut : bool Whether to use float16 lookup tables (if use_gpu is True). use_opq : bool Whether to use Optimized Product Quantization (OPQ). Returns ------- faiss.Index The trained FAISS index. """ # quantizer + metric quantizer = ( faiss.IndexFlatIP(self.dim) if self.metric == "cosine" else faiss.IndexFlatL2(self.dim) ) metric_type = ( faiss.METRIC_INNER_PRODUCT if self.metric == "cosine" else faiss.METRIC_L2 ) if use_opq: opq = faiss.OPQMatrix(self.dim, m) core = faiss.IndexIVFPQ(quantizer, self.dim, nlist, m, nbits, metric_type) cpu_template = faiss.IndexPreTransform(opq, core) else: cpu_template = faiss.IndexIVFPQ( quantizer, self.dim, nlist, m, nbits, metric_type ) if not use_gpu: cpu_template.train(train_np) return cpu_template res = faiss.StandardGpuResources() co = faiss.GpuClonerOptions() co.useFloat16 = gpu_fp16_lut gpu_index = faiss.index_cpu_to_gpu(res, gpu_device, cpu_template, co) gpu_index.train(train_np) return gpu_index
[docs] def build_index( self, index_path: str, sample_size: int = 655_360, nlist: int = 16384, m: int = 64, nbits: int = 8, use_gpu: bool = True, add_batch_size: int = 100_000, nprobe: int = 16, gpu_device: int = 0, gpu_fp16_lut: bool = True, use_opq: bool = True, tokens_dir: Optional[str] = None, compress_sequences: bool = False, ) -> faiss.Index: """Build an OPQ+IVF-PQ index. Parameters ---------- index_path : str Path to save the built index (e.g., "my_index.faiss") sample_size : int, optional Number of samples to use for training, by default 655_360 nlist : int, optional Number of inverted file list (IVF) partitions, by default 16384 m : int, optional Number of subquantizers, by default 64 nbits : int, optional Number of bits per subvector, by default 8 use_gpu : bool, optional Whether to use GPU for training, by default True add_batch_size : int, optional Batch size for adding vectors to the index, by default 100_000 nprobe : int, optional Number of probes for the IVFPQ index, by default 16 gpu_device : int, optional GPU device ID to use (if use_gpu is True), by default 0 gpu_fp16_lut : bool, optional Whether to use float16 lookup tables (if use_gpu is True), by default True use_opq : bool, optional Whether to use Optimized Product Quantization (OPQ), by default True tokens_dir : Optional[str], optional Directory containing tokenized sequences, by default None compress_sequences : bool, optional Whether to compress sequences, by default False Returns ------- faiss.Index The built FAISS index. Raises ------ RuntimeError If the index cannot be built. """ if self.dim is None or self.total == 0: raise RuntimeError("No vectors discovered under root; cannot build") self._log( f"[BUILD] start total={self.total} dim={self.dim} sample={sample_size} " f"nlist={nlist} m={m} nbits={nbits} gpu={use_gpu}" ) t0 = time.time() # Sample and prepare training data train = self.sample_vectors(sample_size) train_np = train.detach().cpu().numpy().astype("float32", copy=False) if self.metric == "cosine": faiss.normalize_L2(train_np) # Adapt nlist based on training size needed = 39 max_nlist = max(1, int(train_np.shape[0] // needed)) if nlist > max_nlist: pow2 = 1 while (pow2 << 1) <= max_nlist: pow2 <<= 1 self._log( f"[BUILD] reducing nlist {nlist} -> {pow2} (train={train_np.shape[0]})" ) nlist = max(1, pow2) index = self._create_and_train_index( train_np=train_np, nlist=nlist, m=m, nbits=nbits, use_gpu=use_gpu, gpu_device=gpu_device, gpu_fp16_lut=gpu_fp16_lut, use_opq=use_opq, ) # Set nprobe on the core IVFPQ index ivfpq_index = faiss.downcast_index( index.index if isinstance(index, faiss.IndexPreTransform) else index ) ivfpq_index.nprobe = min(int(nprobe), int(nlist)) self._log("[BUILD] Populating index with all vectors...") gid_base = 0 batch_ids = np.empty(add_batch_size, dtype=np.int64) for fi, (p, n) in tqdm(enumerate(zip(self.files, self.counts), start=1)): arr = ( self._load_features(p) .detach() .cpu() .numpy() .astype("float32", copy=False) ) if self.metric == "cosine": faiss.normalize_L2(arr) s = 0 while s < n: e = min(s + add_batch_size, n) batch_size = e - s # Use view instead of copy when possible if s == 0 and e == n: batch = arr else: batch = np.ascontiguousarray(arr[s:e], dtype=np.float32) # Reuse pre-allocated ID array batch_ids[:batch_size] = np.arange( gid_base + s, gid_base + e, dtype=np.int64 ) index.add_with_ids(batch, batch_ids[:batch_size]) s = e gid_base += n if fi % 25 == 0 or fi == len(self.files): self._log( f"[BUILD][ADD] {fi}/{len(self.files)} gid={gid_base}/{self.total} ({gid_base / self.total:.1%})" ) # Save index and manifest self._log("[BUILD] Saving index to disk...") self.save_index(index, index_path) self.save_manifest( index_path + ".manifest.json", nlist=nlist, m=m, nbits=nbits, sample_size=sample_size, ) if tokens_dir: self._log("[SEQSTORE] Building sequences.sqlite from tokens_dir...") self.build_sequence_store( index_path + ".seqs.sqlite", tokens_dir, compress=compress_sequences ) self._log("[SEQSTORE] Done") self._log(f"[BUILD] done in {time.time() - t0:.1f}s") return index
[docs] def save_index(self, index: faiss.Index, index_path: str) -> None: """Save index to disk (converts GPU to CPU if needed).""" try: cpu_index = faiss.index_gpu_to_cpu(index) except Exception: cpu_index = index t0 = time.time() faiss.write_index(cpu_index, index_path) if self.verbose: self._log(f"[SAVE] Wrote {index_path} in {time.time() - t0:.2f}s")
[docs] def save_manifest( self, path: str, *, nlist: int, m: int, nbits: int, sample_size: int ) -> None: """Save a JSON manifest with index metadata.""" data = { "version": 1, "faiss_version": getattr(faiss, "__version__", "unknown"), "metric": self.metric, "dim": int(self.dim) if self.dim is not None else None, "total": int(self.total), "ivfpq": {"nlist": nlist, "m": m, "nbits": nbits}, "opq": {"m": m}, "train": {"sample_size": sample_size}, "build_date": time.strftime("%Y-%m-%d"), } with open(path, "w") as f: json.dump(data, f) if self.verbose: self._log(f"[MANIFEST] Wrote {path}")
[docs] def build_sequence_store( self, db_path: str, tokens_dir: str, compress: bool = False, batch: int = 50_000 ) -> None: """Build a SequenceStore SQLite database from tokenized sequence files. Parameters ---------- db_path : str The path to the SQLite database file to create. tokens_dir : str The directory containing the tokenized sequence files. compress : bool, optional Whether to compress the stored sequences, by default False. batch : int, optional The number of rows to insert in each batch, by default 50_000. Raises ------ FileNotFoundError If a tokenized sequence file is missing. ValueError If there is a length mismatch between tokens and features. """ store = SequenceStore.open_writer(db_path) gid_base = 0 rows = [] for feat_path, feat_count, shard_id in zip( self.files, self.counts, self.shard_ids ): tok_path = os.path.join( tokens_dir, f"uniref50_idrs_only_{shard_id:06d}.tokens.pt" ) if not os.path.exists(tok_path): raise FileNotFoundError(f"Tokens file missing: {tok_path}") recs = torch.load(tok_path, map_location="cpu") tok_count = len(recs) if tok_count != feat_count: raise ValueError( f"Length mismatch for {tok_path}: tokens={tok_count} vs features={feat_count} (shard {shard_id:06d})" ) for local_idx, r in enumerate(recs): seq = r["sequence"] header = r.get("header") rows.append( ( gid_base + local_idx, len(seq), SequenceStore.hash8(seq), SequenceStore.encode_seq(seq, compress), shard_id, local_idx, SequenceStore.encode_header(header, compress), ) ) if len(rows) >= batch: store.insert_rows(rows) rows.clear() # advance by the features count for this shard gid_base += feat_count if rows: store.insert_rows(rows) store.close_publish() self._log(f"[SEQSTORE] Wrote {db_path} with {gid_base} entries")