Source code for starling.search.search_engine

"""
FAISS Search Engine
===================

High-performance similarity search with flexible filtering, length gating, and exact reranking.

Overview
--------
The SearchEngine provides fast ANN (Approximate Nearest Neighbor) search with:

* **Multi-level Filtering**: Embedding distance, sequence length, exact matches, identity
* **Length Gating**: Pre-filter candidates by sequence length using indexed lookups
* **Reranking**: Exact rescoring of top-k candidates using full encoder
* **Batch Processing**: Efficient handling of multiple queries
* **Flexible Metrics**: Cosine similarity or L2 distance


Basic Usage
-----------
>>> from starling.search import SearchEngine
>>> import torch
>>> engine = SearchEngine.load("my_index.faiss", metric="cosine")
>>> queries = torch.randn(10, 768)
>>> queries = torch.nn.functional.normalize(queries, dim=1)
>>> results = engine.search(queries=queries, k=100, nprobe=128, return_similarity=True)
>>> for qi, hits in enumerate(results):
...     for score, gid, header, length in hits[:5]:
...         print(qi, score, gid, length)

Advanced Usage
--------------
**Filtering by Length:**
>>> engine.search(queries, k=100, nprobe=128, length_min=50, length_max=500)

**Excluding Exact Matches:**
>>> engine.search(queries, query_sequences=["MKTLLIL..."], k=100, exclude_exact=True)

**Exact Reranking:**
>>> engine.search(queries, k=100, nprobe=128, rerank=True, rerank_device="cuda:0")

Search Parameters
-----------------
See ``search()`` docstring for the full parameter list.

Common Patterns
---------------
Pattern 1: Near duplicates (filter exact + very similar)
>>> engine.search(queries, k=100, nprobe=256, exclude_exact=True, max_cosine_similarity=0.99)

Pattern 2: Length-focused neighborhood
>>> L = 200
>>> engine.search(queries, k=1000, nprobe=128, length_min=L-50, length_max=L+50)

Pattern 3: Diverse similar sequences
>>> engine.search(queries, k=500, nprobe=128, max_cosine_similarity=0.80, length_min=50, length_max=500)

Notes
-----
* Normalize queries for cosine: ``torch.nn.functional.normalize(q, dim=1)``
* Higher ``nprobe`` improves recall at cost of latency
* Use ``rerank`` for improved precision after coarse ANN stage
"""

from __future__ import annotations

import os
from typing import Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch

from starling.search.search_utils import (
    Candidate,
    CandidateFilter,
    CosineSimFilter,
    ExactMatchFilter,
    L2DistanceFilter,
    LengthFilter,
    ScoreConverter,
    SequenceIdentityFilter,
    ValidGidFilter,
)
from starling.search.store import SequenceStore

try:
    import faiss
except Exception as e:
    raise ImportError(
        "FAISS is required. Please install a CPU or GPU compatible version"
    ) from e


[docs] class SearchEngine: """FAISS-backed similarity search with rich post-filtering. The :class:`SearchEngine` wraps a trained ``faiss.Index`` and optional :class:`~starling.search.store.SequenceStore` to provide high-level search primitives used across STARLING. In addition to ANN lookups, it supports length gating, identity thresholds, exact-match suppression, and optional reranking with the encoder. Parameters ---------- index : faiss.Index Trained FAISS index (e.g., IVFPQ, HNSW) matching the embedding metric. metric : {'cosine', 'l2'}, default 'cosine' Similarity/distance metric used during index construction. seq_store : SequenceStore, optional Sequence/metadata store generated alongside the index. Required for filters that need sequence lengths or raw sequences, and for reranking. verbose : bool, default True When ``True`` the engine logs key operations (loading, filters, rerank). Attributes ---------- index : faiss.Index Underlying FAISS index used for coarse ANN search. metric : str Metric string supplied during construction (``"cosine"`` or ``"l2"``). seq_store : SequenceStore or None Attached sequence store used for metadata-aware filtering. total : int Number of vectors stored in the index. dim : int Embedding dimensionality of the index. Examples -------- >>> from starling.search import SearchEngine >>> engine = SearchEngine.load("/data/index.faiss") >>> hits = engine.search(queries, k=50, nprobe=128, rerank=True) >>> len(hits[0]) 50 """
[docs] def __init__( self, index: faiss.Index, metric: str = "cosine", seq_store: Optional[SequenceStore] = None, verbose: bool = True, ): if metric not in {"l2", "cosine"}: raise ValueError("metric must be 'l2' or 'cosine'") self.index = index self.metric = metric self.seq_store = seq_store self.verbose = verbose self._log = print if verbose else (lambda *a, **k: None) self.total = int(index.ntotal) self.dim = int(index.d)
[docs] @classmethod def load( cls, index_path: str, metric: str = "cosine", verbose: bool = True, ) -> "SearchEngine": """Load a serialized FAISS index (and optional sequence store). This attempts to read ``index_path`` (a .faiss file) and, if present, a sibling ``.seqs.sqlite`` file with the same base name. The sequence store is optional; absence only disables filters / reranking that need it. Parameters ---------- index_path : str Path to the serialized FAISS index (typically ends with ``.faiss``). metric : {'cosine', 'l2'}, default 'cosine' Metric that will be used for queries. Must match how embeddings were prepared. verbose : bool, default True Whether to emit log messages during loading. Returns ------- SearchEngine A ready-to-query search engine instance. """ index = faiss.read_index(index_path) # Try to load sequence store if available seq_store = None seq_db = index_path + ".seqs.sqlite" if os.path.exists(seq_db): try: seq_store = SequenceStore.open_reader(seq_db) if verbose: print(f"[SEQSTORE] Attached {seq_db}") except Exception as e: if verbose: print(f"[SEQSTORE] Failed to open {seq_db}: {e}") return cls(index=index, metric=metric, seq_store=seq_store, verbose=verbose)
def _compute_fetch_k( self, k: int, overfetch: Optional[int], any_filters: bool, ) -> int: """Compute internal candidate fetch size prior to filtering. Overfetching helps preserve recall when downstream filters prune results. Parameters ---------- k : int Desired number of final results per query. overfetch : int, optional User-provided multiplicative factor. If ``None``, it is auto-derived. any_filters : bool Indicator that at least one filter (besides ValidGid) is active. Returns ------- int Number of candidates to request from FAISS per query. """ if overfetch is None: overfetch = 5 if any_filters else 1 fetch_k = k * max(1, overfetch) return fetch_k def _build_filters( self, return_similarity: bool, exclude_exact: bool, sequence_identity_max: Optional[float], identity_denominator: str, min_l2_distance: Optional[float], max_cosine_similarity: Optional[float], length_min: Optional[int], length_max: Optional[int], ) -> List[CandidateFilter]: """Instantiate the base (query-agnostic) filter objects. Sequence-dependent filters (exact match / identity) are created per-query later because they depend on the query sequence content. Parameters ---------- return_similarity : bool Whether outward facing scores are similarities (True) or distances. exclude_exact : bool If True, exact sequence matches (by hash + content) will be removed. sequence_identity_max : float, optional Maximum allowed identity; candidates exceeding are filtered. identity_denominator : {'query','target','max','min','avg'} Identity denominator mode (passed along to the identity function). min_l2_distance : float, optional Minimum allowed L2 distance (applies only for ``metric='l2'``). max_cosine_similarity : float, optional Maximum allowed cosine similarity (applies only for ``metric='cosine'``). length_min : int, optional Minimum sequence length (inclusive). length_max : int, optional Maximum sequence length (inclusive). Returns ------- list of CandidateFilter Base filters (ValidGid + optional embedding/length filters). """ filters: List[CandidateFilter] = [ValidGidFilter()] # Embedding-level filters if self.metric == "l2" and min_l2_distance is not None: filters.append(L2DistanceFilter(min_l2_distance)) if self.metric == "cosine" and max_cosine_similarity is not None: filters.append(CosineSimFilter(max_cosine_similarity, return_similarity)) # Length filter if length_min is not None or length_max is not None: filters.append(LengthFilter(length_min, length_max)) # Sequence-based filters will be instantiated per-query return filters def _set_nprobe(self, nprobe: Optional[int]) -> None: """Safely set ``nprobe`` on the underlying IVF index if available. Parameters ---------- nprobe : int, optional Number of IVF lists to probe. ``None`` leaves index default unchanged. Notes ----- * Silently no-ops if the index has no IVF component. * Removes any active ``max_codes`` cap (sets to 0 / unlimited). """ if nprobe is None: return # Try to set on inner IVF index (works for IndexPreTransform too) inner = None try: core = self.index.index if hasattr(self.index, "index") else self.index inner = faiss.downcast_index(core) except Exception as e: self._log(f"[WARN] Failed to downcast index: {e}") if inner is not None and hasattr(inner, "nprobe"): inner.nprobe = int(nprobe) # Remove scan caps if present if hasattr(inner, "max_codes"): try: inner.max_codes = 0 # 0 => unlimited except Exception as e: self._log(f"[WARN] Failed to set max_codes: {e}") self._log( "[SEARCH] Set nprobe={} nlist={} max_codes={}".format( getattr(inner, "nprobe", None), getattr(inner, "nlist", None), getattr(inner, "max_codes", None), ) ) elif hasattr(self.index, "nprobe"): # Fallback for non-wrapped IVF self.index.nprobe = int(nprobe) self._log(f"[SEARCH] Set nprobe={nprobe} (direct)") else: self._log("[WARN] Could not set nprobe, index has no nprobe attribute") def _ann_search( self, queries_np: np.ndarray, fetch_k: int, nprobe: Optional[int], selector_ids: Optional[Sequence[int]] = None, ) -> Tuple[np.ndarray, np.ndarray]: """Execute ANN search against FAISS with optional ID restriction. Parameters ---------- queries_np : ndarray, shape (Q, D) Query vectors (float32) already normalized if cosine metric. fetch_k : int Number of candidates to request from FAISS per query. nprobe : int, optional IVF probe count override. selector_ids : sequence of int, optional If provided, attempt to restrict search to these IDs (uses IDSelectorBatch). Returns ------- D : ndarray, shape (Q, fetch_k) Raw FAISS scores (inner products or squared distances). idxs : ndarray, shape (Q, fetch_k) Candidate global IDs (``-1`` for invalid slots). Notes ----- Falls back to standard search if selector path fails or unsupported. """ self._set_nprobe(nprobe) if selector_ids is not None: present = sum(1 for g in selector_ids if 0 <= int(g) < int(self.total)) self._log( f"[SEARCH] Using selector: {present}/{len(selector_ids)} IDs in range " f"(total={self.total})" ) try: params = faiss.SearchParametersIVF() if nprobe is not None: params.nprobe = int(nprobe) sel = faiss.IDSelectorBatch(np.asarray(selector_ids, dtype=np.int64)) params.sel = sel # note: you *must* use params as a kwarg otherwise faiss python wrappers fail # and you experience pain.... D, idxs = self.index.search(queries_np, fetch_k, params=params) self._log("[SEARCH] IDSelectorBatch search completed successfully.") return D, idxs except Exception as e: self._log( f"[WARN] IDSelectorBatch failed: {e}, falling back to standard search" ) self._log("[SEARCH] Using standard search (no selector).") D, idxs = self.index.search(queries_np, fetch_k) return D, idxs def _collect_meta(self, idxs: np.ndarray) -> dict: """Fetch metadata for unique candidate IDs present in result indices. Parameters ---------- idxs : ndarray, shape (Q, K) Candidate ID matrix (``-1`` entries ignored). Returns ------- dict Mapping ``gid -> (header, length, hash8)``. Empty if no sequence store. """ gids = {int(g) for g in idxs.flatten() if g >= 0} if not (self.seq_store and gids): return {} fetched = self.seq_store.get_many_meta(gids) return {gid: (hdr, length, h8) for gid, hdr, length, h8 in fetched} def _seq_identity(self, seq1: str, seq2: str, denom: str = "query") -> float: """Compute approximate (ungapped) sequence identity. This is a fast heuristic: counts exact character matches over the overlapping prefix (``min(len(seq1), len(seq2))``) and divides by a denominator selected by ``denom``. Any overhang beyond the overlap is implicitly treated as mismatched when the denominator is larger than the overlap (e.g. ``max`` / ``avg`` cases). Parameters ---------- seq1 : str First (query) sequence. seq2 : str Second (target) sequence. denom : {'query','target','max','min','avg'}, default 'query' Denominator policy: * query -> len(seq1) * target -> len(seq2) * max -> max(len1, len2) * min -> min(len1, len2) * avg -> 0.5 * (len1 + len2) Returns ------- float Identity fraction in [0,1]. Returns 0.0 if any sequence is empty or denominator resolves to 0. Notes ----- This does NOT perform alignment (no gap handling). For short indels the value may underestimate true identity. Suitable for coarse filtering. """ if not seq1 or not seq2: return 0.0 if seq1 == seq2: return 1.0 l1 = len(seq1) l2 = len(seq2) overlap = min(l1, l2) matches = 0 # Fast loop over overlap for a, b in zip(seq1[:overlap], seq2[:overlap]): if a == b: matches += 1 if denom == "query": denom_len = l1 elif denom == "target": denom_len = l2 elif denom == "max": denom_len = max(l1, l2) elif denom == "min": denom_len = overlap # == min(l1, l2) elif denom == "avg": denom_len = 0.5 * (l1 + l2) else: # Fallback to query length denom_len = l1 if denom_len <= 0: return 0.0 return matches / float(denom_len) def _filter_candidates( self, D: np.ndarray, idxs: np.ndarray, k: int, converter: ScoreConverter, query_sequences: Optional[List[str]], exclude_exact_sequence: bool, sequence_identity_max: Optional[float], identity_denominator: str, base_filters: List[CandidateFilter], meta_map: dict, ) -> Tuple[List[List[Tuple[float, int, Optional[str], Optional[int]]]], set]: """Apply filters for all queries and build preliminary top-k lists. Parameters ---------- D : ndarray, shape (Q, F) Raw FAISS scores. idxs : ndarray, shape (Q, F) Candidate IDs (aligned with ``D``), ``-1`` marks empty slots. k : int Final desired results per query. converter : ScoreConverter Handles score conversion (cosine IP -> similarity / distance, etc.). query_sequences : list of str, optional Raw query sequences (needed for exact/identity filters). exclude_exact : bool If True, remove exact (sequence-equal) matches. sequence_identity_max : float, optional Identity threshold; above this candidates are removed. identity_denominator : str Denominator mode passed to identity filter. base_filters : list of CandidateFilter Filters that are query-agnostic. meta_map : dict Metadata map from ``gid`` to (header, length, hash8). Returns ------- preliminary : list of list of tuple Per-query list of (score, gid, header, length) tuples after filtering. rerank_gid_set : set of int Unique GIDs that survived at least one query (union for reranking). """ Q = D.shape[0] # Pre-compute query hashes if needed query_hashes = {} if exclude_exact_sequence and query_sequences and self.seq_store: query_hashes = { i: self.seq_store.hash8(s) for i, s in enumerate(query_sequences) } out: List[List[Tuple[float, int, Optional[str], Optional[int]]]] = [] rerank_gid_set: set = set() for qi in range(Q): # Build per-query filters (includes sequence-specific filters) filters = list(base_filters) qseq = query_sequences[qi] if query_sequences is not None else None if ( exclude_exact_sequence and qseq and self.seq_store and qi in query_hashes ): filters.append(ExactMatchFilter(query_hashes[qi], self.seq_store)) if sequence_identity_max is not None and qseq and self.seq_store: filters.append( SequenceIdentityFilter( sequence_identity_max, identity_denominator, self.seq_store, self._seq_identity, ) ) # Filter candidates for this query row, filter_counts = self._filter_query_candidates( D[qi], idxs[qi], k, converter, qseq, filters, meta_map ) self._log( f"[FILTER][q{qi}] total={filter_counts['total']} kept={len(row)} | " + " ".join( f"filt_{k}={v}" for k, v in filter_counts.items() if k != "total" ) ) out.append(row) for _, gid, _, _ in row: rerank_gid_set.add(gid) return out, rerank_gid_set def _filter_query_candidates( self, scores: np.ndarray, gids: np.ndarray, k: int, converter: ScoreConverter, query_seq: Optional[str], filters: List[CandidateFilter], meta_map: dict, ) -> Tuple[List[Tuple[float, int, Optional[str], Optional[int]]], dict]: """Filter candidates for a single query row. Parameters ---------- scores : ndarray, shape (F,) Raw FAISS scores for one query. gids : ndarray, shape (F,) Candidate IDs aligned with ``scores``. k : int Desired number of outputs (early stop once reached). converter : ScoreConverter Converter to transform raw score to outward facing metric. query_seq : str, optional Raw query sequence (for exact / identity filters); may be None. filters : list of CandidateFilter Active filters to apply in order. meta_map : dict Mapping gid -> (header, length, hash8) metadata. Returns ------- row : list of tuple Filtered (score, gid, header, length) tuples (<= k items). filter_counts : dict Per-filter rejection counts plus total candidates examined. """ row: List[Tuple[float, int, Optional[str], Optional[int]]] = [] filter_counts = {f.get_name(): 0 for f in filters} filter_counts["total"] = 0 for raw_score, gid in zip(scores, gids): filter_counts["total"] += 1 # Build candidate header, length, stored_hash = meta_map.get(int(gid), (None, None, None)) candidate = Candidate( score=converter.convert(raw_score), gid=int(gid), header=header, length=length, stored_hash=stored_hash, ) # Apply filters passed = True for f in filters: if not f.apply(candidate, query_seq): filter_counts[f.get_name()] += 1 passed = False break if passed: row.append(candidate.as_tuple()) if len(row) >= k: break return row, filter_counts def _encode_sequences_for_rerank( self, gids: Iterable[int], device: Optional[str], batch_size: int, ionic_strength: Optional[int], ) -> dict: """Encode candidate sequences for exact reranking. Parameters ---------- gids : iterable of int Unique candidate GIDs to materialize/encode. device : str, optional Device string passed to encoder (e.g., 'cuda:0', 'cpu'). Auto if None. batch_size : int Encoder batch size. ionic_strength : int, optional Ionic strength forwarded to the encoder (controls deterministic dropout etc.). Returns ------- dict Mapping ``gid -> embedding (torch.Tensor, shape (D,))``. Empty if no seq store. """ if not self.seq_store: return {} seq_map = {} requested = 0 have_sequences = 0 for gid in gids: requested += 1 s = self.seq_store.get_seq(gid) if s is not None: have_sequences += 1 seq_map[f"g{gid}"] = s if not seq_map: self._log( f"[RERANK] requested_gids={requested} have_sequences=0 encoded_vecs=0" ) return {} from starling.inference.generation import sequence_encoder_backend chosen_device = device or ("cuda:0" if torch.cuda.is_available() else "cpu") ionic_strength = 150 if ionic_strength is None else ionic_strength emb_dict = sequence_encoder_backend( sequence_dict=seq_map, device=chosen_device, batch_size=batch_size, ionic_strength=ionic_strength, aggregate=True, output_directory=None, ) # convert to gid -> tensor out: dict[int, torch.Tensor] = {} for k, v in emb_dict.items(): gid = int(k[1:]) # strip leading 'g' t = v if isinstance(v, torch.Tensor) else torch.as_tensor(v) out[gid] = t.float() # normalize if cosine metric if self.metric == "cosine": for gid, t in out.items(): out[gid] = torch.nn.functional.normalize(t.unsqueeze(0), dim=1).squeeze( 0 ) self._log( f"[RERANK] requested_gids={requested} have_sequences={have_sequences} encoded_vecs={len(out)}" ) return out def _rerank( self, preliminary: List[List[Tuple[float, int, Optional[str], Optional[int]]]], queries: torch.Tensor, gid_to_vec: dict, k: int, converter: ScoreConverter, ) -> List[List[Tuple[float, int, Optional[str], Optional[int]]]]: """Rerank preliminary candidates using freshly encoded exact vectors. Parameters ---------- preliminary : list of list Per-query preliminary candidate tuples (score, gid, header, length). queries : torch.Tensor, shape (Q, D) Original query embeddings (pre-normalized if cosine). gid_to_vec : dict Mapping gid -> candidate embedding for reranking. k : int Final number of results to keep. converter : ScoreConverter Score conversion utility (determines ascending/descending ordering logic). Returns ------- list of list Per-query reranked list of (score, gid, header, length) tuples. """ if not gid_to_vec: self._log( "[RERANK] No vectors available for rerank; returning prelim top-k" ) return [row[:k] for row in preliminary] q_emb = queries.detach().cpu().float() if self.metric == "cosine": q_emb = torch.nn.functional.normalize(q_emb, dim=1) final: List[List[Tuple[float, int, Optional[str], Optional[int]]]] = [] for qi, row in enumerate(preliminary): if not row: final.append([]) self._log(f"[RERANK][q{qi}] prelim=0 have_vecs=0 final=0") continue # Find candidates with available vectors cand = [gid for _, gid, _, _ in row if gid in gid_to_vec] if not cand: final.append(row[:k]) self._log( f"[RERANK][q{qi}] prelim={len(row)} have_vecs=0 final={min(len(row), k)}" ) continue # Compute exact scores mat = torch.stack([gid_to_vec[g] for g in cand]) # (N,D) qv = q_emb[qi] if self.metric == "cosine": sims = (qv.unsqueeze(0) @ mat.T).squeeze(0) raw_scores = sims if converter.return_similarity else (1.0 - sims) desc = converter.return_similarity else: # l2 distance diffs = mat - qv raw_scores = torch.sqrt(torch.sum(diffs * diffs, dim=1)) desc = False # Sort and build result order = torch.argsort(raw_scores, descending=desc) meta_lookup = {gid: (hdr, ln) for _, gid, hdr, ln in row} new_row: List[Tuple[float, int, Optional[str], Optional[int]]] = [] for oi in order[:k]: gid = cand[int(oi)] hdr, ln = meta_lookup.get(gid, (None, None)) new_row.append((float(raw_scores[int(oi)].item()), gid, hdr, ln)) final.append(new_row) self._log( f"[RERANK][q{qi}] prelim={len(row)} have_vecs={len(cand)} final={len(new_row)}" ) return final
[docs] def search( self, queries: torch.Tensor, k: int = 10, nprobe: Optional[int] = None, return_similarity: bool = False, query_sequences: Optional[List[str]] = None, exclude_exact: bool = False, sequence_identity_max: Optional[float] = None, identity_denominator: str = "query", min_l2_distance: Optional[float] = None, max_cosine_similarity: Optional[float] = None, overfetch: Optional[int] = None, rerank: bool = False, rerank_device: Optional[str] = None, rerank_batch_size: int = 64, rerank_ionic_strength: Optional[int] = None, length_min: Optional[int] = None, length_max: Optional[int] = None, ) -> List[List[Tuple[float, int, Optional[str], Optional[int]]]]: """Approximate nearest neighbor search with optional filtering and reranking. Parameters ---------- queries : torch.Tensor, shape (Q, D) Query embeddings. For cosine metric they should be L2-normalized externally (the method re-normalizes defensively before search). k : int, default 10 Number of final results per query. nprobe : int, optional IVF probe count override; higher => better recall, slower. return_similarity : bool, default False If True (cosine), return similarity in [0,1]; otherwise distance. query_sequences : list of str, optional Raw sequences aligned with queries (enables exact / identity filters & rerank). exclude_exact : bool, default False Remove candidates whose sequence matches the query exactly. sequence_identity_max : float, optional Maximum allowed identity (0-1). Requires ``query_sequences`` and seq store. identity_denominator : {'query','target','max','min','avg'}, default 'query' Mode for identity normalization (forwarded to identity function). min_l2_distance : float, optional Minimum allowed L2 distance (metric='l2'). max_cosine_similarity : float, optional Maximum allowed cosine similarity (metric='cosine'). overfetch : int, optional Multiply ``k`` by this before filtering. Auto: 5 if filters active else 1. rerank : bool, default False If True, re-embed surviving candidates (exact scoring) and reorder. rerank_device : str, optional Device for reranking encoder; auto-select if None. rerank_batch_size : int, default 64 Batch size for rerank embedding pass. rerank_ionic_strength : int, optional Ionic strength forwarded to encoder (if None uses model default). length_min : int, optional Minimum acceptable candidate length. length_max : int, optional Maximum acceptable candidate length. Returns ------- list of list of tuple Per-query results: ``[(score, gid, header, length), ...]`` up to ``k``. Raises ------ RuntimeError If filters requiring a sequence store are requested but none is attached. """ # Validate requirements seq_filters_active = exclude_exact or sequence_identity_max is not None if ( seq_filters_active or rerank or length_min is not None or length_max is not None ) and self.seq_store is None: raise RuntimeError( "SequenceStore required for filters/length gating/rerank" ) # Prepare queries q_np = queries.detach().cpu().float().numpy() if self.metric == "cosine": faiss.normalize_L2(q_np) # Setup score converter and filters converter = ScoreConverter(self.metric, return_similarity) base_filters = self._build_filters( return_similarity, exclude_exact, sequence_identity_max, identity_denominator, min_l2_distance, max_cosine_similarity, length_min, length_max, ) # Compute fetch_k any_filters = ( len(base_filters) > 1 or exclude_exact or sequence_identity_max is not None ) fetch_k = self._compute_fetch_k(k, overfetch, any_filters) self._log( f"[SEARCH] k={k} fetch_k={fetch_k} nprobe={nprobe} queries={q_np.shape[0]}" ) # ANN search with optional length selector selector_ids = None if self.seq_store and (length_min is not None or length_max is not None): selector_ids = self.seq_store.get_gids_by_length_range( length_min, length_max ) D, idxs = self._ann_search(q_np, fetch_k, nprobe, selector_ids=selector_ids) # Collect metadata and filter meta_map = self._collect_meta(idxs) preliminary, rerank_gid_set = self._filter_candidates( D=D, idxs=idxs, k=k, converter=converter, query_sequences=query_sequences, exclude_exact_sequence=exclude_exact, sequence_identity_max=sequence_identity_max, identity_denominator=identity_denominator, base_filters=base_filters, meta_map=meta_map, ) if not rerank: return [row[:k] for row in preliminary] total_prelim = sum(len(r) for r in preliminary) self._log( f"[RERANK] total_prelim={total_prelim} unique_candidates={len(rerank_gid_set)}" ) gid_to_vec = self._encode_sequences_for_rerank( gids=rerank_gid_set, device=rerank_device, batch_size=rerank_batch_size, ionic_strength=rerank_ionic_strength, ) return self._rerank( preliminary=preliminary, queries=queries, gid_to_vec=gid_to_vec, k=k, converter=converter, )