Source code for starling.data.tokenizer

class StarlingTokenizer:
    """Lightweight amino-acid tokenizer used across STARLING.

    The tokenizer exposes fast byte-level ``encode``/``decode`` helpers that
    map between protein sequences and integer vocab IDs. It is optimized for
    bulk processing inside the sequence encoder and FAISS tokenization
    pipeline.

    Examples
    --------
    >>> from starling.data.tokenizer import StarlingTokenizer
    >>> tok = StarlingTokenizer()
    >>> ids = tok.encode("ACDE")
    >>> ids
    [1, 2, 3, 4]
    >>> tok.decode(ids)
    'ACDE'

    Notes
    -----
    * Unknown characters (including any non-ASCII character) raise
      :class:`KeyError` during encoding.
    * Padding/reserved ``0`` tokens, and any ID that is not in the
      vocabulary, are dropped during decode.
    """

    # Public vocab maps (kept for compatibility with existing code)
    aa_to_int = {
        "0": 0,
        "A": 1,
        "C": 2,
        "D": 3,
        "E": 4,
        "F": 5,
        "G": 6,
        "H": 7,
        "I": 8,
        "K": 9,
        "L": 10,
        "M": 11,
        "N": 12,
        "P": 13,
        "Q": 14,
        "R": 15,
        "S": 16,
        "T": 17,
        "V": 18,
        "W": 19,
        "Y": 20,
    }
    int_to_aa = {v: k for k, v in aa_to_int.items()}

    # Byte value used in the encode table to mark "not in vocabulary";
    # it must not collide with any real ID in ``aa_to_int``.
    _UNKNOWN = 255

    def __init__(self):
        # Encode table: ASCII byte -> vocab ID, with _UNKNOWN for every
        # byte that is not a vocabulary character.
        table = bytearray([self._UNKNOWN] * 256)
        for ch, idx in self.aa_to_int.items():
            table[ord(ch)] = idx
        self._enc_table = bytes(table)

        # Decode table: vocab ID -> ASCII byte of its letter. Unused slots
        # hold '?', but decode() deletes those IDs before translating.
        inv = bytearray([ord("?")] * 256)
        for idx, ch in self.int_to_aa.items():
            inv[idx] = ord(ch)
        self._dec_table = bytes(inv)

        # IDs deleted by decode() before translation: the 0 padding token
        # plus every byte value that is not a vocabulary ID.
        self._dec_drop = bytes(
            i for i in range(256) if i == 0 or i not in self.int_to_aa
        )

    def encode(self, sequence: str) -> "list[int]":
        """Encode a string of amino acids into a list of vocab IDs.

        Parameters
        ----------
        sequence : str
            Characters drawn from ``aa_to_int`` (amino-acid letters and the
            ``"0"`` padding symbol).

        Returns
        -------
        list[int]
            One vocab ID per character of ``sequence``.

        Raises
        ------
        KeyError
            If ``sequence`` contains a character outside ``aa_to_int``,
            including any non-ASCII character. The message names the first
            offending character and its position.
        """
        try:
            raw = sequence.encode("ascii")
        except UnicodeEncodeError:
            # A non-ASCII character can never be a vocabulary token; report it
            # as a KeyError like any other unknown character.
            raise self._unknown_token_error(sequence) from None
        # One C-level pass maps every byte to its ID (_UNKNOWN if not in vocab).
        enc = raw.translate(self._enc_table)
        if self._UNKNOWN in enc:
            raise self._unknown_token_error(sequence)
        return list(enc)

    def _unknown_token_error(self, sequence):
        """Return a KeyError naming the first character of ``sequence`` not in the vocab."""
        for pos, ch in enumerate(sequence):
            if ch not in self.aa_to_int:
                return KeyError(f"Unknown token '{ch}' at position {pos}")
        # Not reached when called from encode(), which only calls this after
        # detecting an unknown character.
        return KeyError("Unknown token")

    def decode(self, sequence) -> str:
        """Decode a sequence of vocab IDs into an amino-acid string.

        Parameters
        ----------
        sequence : list/tuple of int, iterable of int-like values, bytes, or bytearray
            Vocab IDs. Non-bytes inputs are converted element-wise with
            ``int()``.

        Returns
        -------
        str
            The decoded letters. ``0`` padding IDs and IDs that are not in
            ``int_to_aa`` (including values outside 0..255) are dropped.
        """
        if isinstance(sequence, (bytes, bytearray)):
            buf = sequence
        else:
            # Keep only values that fit in a byte; anything else cannot be a
            # vocab ID and must be dropped rather than wrapped modulo 256
            # onto a real residue.
            buf = bytes(v for v in map(int, sequence) if 0 <= v < 256)
        # Delete padding/unknown IDs, then map the remaining IDs to letters.
        return buf.translate(self._dec_table, self._dec_drop).decode("ascii")