diff --git a/.github/workflows/build_jar.yml b/.github/workflows/build_jar.yml new file mode 100644 index 00000000..4b0d4476 --- /dev/null +++ b/.github/workflows/build_jar.yml @@ -0,0 +1,81 @@ +name: Build Java JAR + +on: [push, pull_request, workflow_dispatch] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build_jni: + name: jni on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + include: + - os: ubuntu-latest + outdir: linux_64 + - os: windows-latest + outdir: windows_64 + - os: macos-latest + outdir: osx_64 + steps: + - uses: actions/checkout@v3 + + - name: Install rust toolchain + uses: actions-rs/toolchain@v1 + with: + # stable doesn't have --out-dir + toolchain: nightly + override: true + + - name: Build + working-directory: ./jni + # TODO: 32bit vs 64bit? + # https://github.com/scijava/native-lib-loader + run: cargo build --release -Z unstable-options --out-dir ../build/natives/${{ matrix.outdir }}/ + + - uses: actions/upload-artifact@v3 + with: + name: natives + path: ./build/natives/* + + build_java: + name: java + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + needs: [build_jni] + + steps: + - uses: actions/checkout@v3 + + - name: Load outputs + uses: actions/download-artifact@v3 + with: + name: natives + path: natives + + - name: Set up JDK 11 + uses: actions/setup-java@v3 + with: + java-version: '11' + distribution: 'microsoft' + architecture: x64 + cache: maven + + - name: Build with Maven + working-directory: ./java + run: mvn --batch-mode package failsafe:integration-test + + - uses: actions/upload-artifact@v3 + with: + name: java + path: ./java/target/*.jar + + # TODO: publish to maven (only from ubuntu) + diff --git a/Cargo.toml b/Cargo.toml index 1fb806bc..bcc4bb52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,21 +1,15 @@ -[package] -name = "tiktoken" -version = "0.2.0" -edition = "2021" -rust-version = "1.57.0" +[workspace] -[lib] -name = "_tiktoken" -crate-type = ["cdylib"] - -[dependencies] -pyo3 = { version = "0.17.3", features = ["extension-module"] } - -# tiktoken dependencies -fancy-regex = "0.10.0" -regex = "1.7.0" -rustc-hash = "1.1.0" -bstr = "1.0.1" +members = [ + "core", + "python", + "jni", +] [profile.release] incremental = true +opt-level = 'z' # Optimize for size +lto = true # Enable link-time optimization +codegen-units = 1 # Reduce number of codegen units to increase optimizations +panic = 'abort' # Abort on panic +strip = true # Strip symbols from binary* \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 7f25b271..321b66e2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,3 +6,4 @@ global-include py.typed recursive-include scripts *.py recursive-include tests *.py recursive-include src *.rs +include tiktoken *.json \ No newline at end of file diff --git a/core/Cargo.toml b/core/Cargo.toml new file mode 100644 index 00000000..53688fd4 --- /dev/null +++ b/core/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tiktoken_core" +version = "0.2.0" +edition = "2021" +rust-version = "1.57.0" + +[lib] +name = "_tiktoken_core" +crate-type = ["lib"] + +[dependencies] +# tiktoken dependencies +fancy-regex = "0.10.0" +regex = "1.7.0" +rustc-hash = "1.1.0" +bstr = "1.0.1" +reqwest = { version = "0.11.14", features = ["blocking"] } +sha1 = "0.10.5" +json = "0.12.4" +base64 = "0.21.0" +lazy_static = "1.4.0" 
diff --git a/src/lib.rs b/core/src/lib.rs similarity index 63% rename from src/lib.rs rename to core/src/lib.rs index b44d9c8b..52cacb04 100644 --- a/src/lib.rs +++ b/core/src/lib.rs @@ -1,124 +1,17 @@ -// This check is new and seems buggy (possibly with PyO3 interaction) -#![allow(clippy::borrow_deref_ref)] - use std::collections::HashSet; use std::thread; use fancy_regex::Regex; -use pyo3::exceptions; -use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyList, PyTuple}; -use pyo3::PyResult; use rustc_hash::FxHashMap as HashMap; -fn _byte_pair_merge( - piece: &[u8], - ranks: &HashMap, usize>, - f: impl Fn(std::ops::Range) -> T, -) -> Vec { - // This is a vector of (start, rank). - // The rank is of the byte pair starting at position start. - // The rank of the last item in the vector is not a valid value. - let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect(); - - // NOTE: using a macro here because a closure fails to get inlined - // according to optimization remarks. - // A closure also cannot capture a reference to `piece` without - // the borrow checker complaining about the mutable borrows during - // the assignments later in this code. - macro_rules! get_rank { - ($start_idx:expr, $skip:expr) => {{ - let start_idx: usize = $start_idx; - let skip: usize = $skip; - if (start_idx + skip + 2) < parts.len() { - ranks - .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) - .map(|r| *r) - } else { - None - } - }}; - ($idx:expr) => {{ - get_rank!($idx, 0) - }}; - } +mod util; +mod load; +pub mod openai_public; - // We look up the ranks once in the beggining and iteratively update - // them during each merge, which reduces the number of rank lookups. - for i in 0..parts.len() - 2 { - match get_rank!(i) { - Some(rank) => { - // usize::MAX is a sentinel value and cannot be a valid rank - debug_assert!(rank != usize::MAX); - parts[i].1 = rank; - } - None => { - continue; - } - }; - } +#[macro_use] +extern crate lazy_static; - // If you have n parts and m merges, this does O(mn) work. - // We could do something with a heap and do O(m log n) work. - // It is important to consider that n is often small (<100), and as such - // the cache-locality benefits outweigh the algorithmic complexity downsides - // of the `parts` vector data structure above. - - // Note that we hash bytes, not token pairs. As long as we train BPE the way we - // currently do, this is equivalent. An easy way to break this would be to decouple - // merge priority from token index or to prevent specific token merges. - loop { - if parts.len() == 1 { - break; - } - - // usize::MAX is a sentinel rank value allowing us to - // take the min more quickly - let mut min_rank: (usize, usize) = (usize::MAX, 0); - for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { - if rank < min_rank.0 { - min_rank = (rank, i); - } - } - - if min_rank.0 != usize::MAX { - let i = min_rank.1; - - // NOTE: We are about to remove parts[i + 1]. We do not do it - // yet because there are cache-locality benefits to updating - // parts[i] and parts[i-1] before removing, which could thrash - // the cache. Thus, we update the rank calculation by skipping over - // parts[i + 1], by invoking `get_rank!` with `skip = 1`. 
- parts[i].1 = get_rank!(i, 1).unwrap_or(usize::MAX); - if i > 0 { - parts[i - 1].1 = get_rank!(i - 1, 1).unwrap_or(usize::MAX); - } - - parts.remove(i + 1); - } else { - break; - } - } - let mut out: Vec = Vec::with_capacity(parts.len() - 1); - for i in 0..parts.len() - 1 { - out.push(f(parts[i].0..parts[i + 1].0)); - } - out -} - -pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, usize>) -> Vec { - if piece.len() == 1 { - return vec![ranks[piece]]; - } - _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]]) -} - -pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, usize>) -> Vec<&'a [u8]> { - if piece.len() == 1 { - return vec![piece]; - } - _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end]) -} +const MAX_NUM_THREADS: usize = 128; // Various performance notes: // @@ -177,9 +70,7 @@ fn hash_current_thread() -> usize { u64::from(x) as usize } -const MAX_NUM_THREADS: usize = 128; -#[pyclass] -struct CoreBPE { +pub struct CoreBPENative { encoder: HashMap, usize>, special_tokens_encoder: HashMap, decoder: HashMap>, @@ -189,7 +80,7 @@ struct CoreBPE { sorted_token_bytes: Vec>, } -impl CoreBPE { +impl CoreBPENative { fn _get_tl_regex(&self) -> &Regex { // See performance notes above for what this is about // It's also a little janky, please make a better version of it! @@ -201,7 +92,7 @@ impl CoreBPE { &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] } - fn _decode_native(&self, tokens: &[usize]) -> Vec { + pub fn _decode_native(&self, tokens: &[usize]) -> Vec { let mut ret = Vec::with_capacity(tokens.len() * 2); for token in tokens { let token_bytes = self @@ -213,7 +104,7 @@ impl CoreBPE { ret } - fn _encode_ordinary_native(&self, text: &str) -> Vec { + pub fn _encode_ordinary_native(&self, text: &str) -> Vec { // This is the core of the encoding logic; the other functions in here // just make things complicated :-) let regex = self._get_tl_regex(); @@ -224,12 +115,13 @@ impl CoreBPE { ret.push(*token); continue; } - ret.extend(&byte_pair_encode(piece, &self.encoder)); + ret.extend(&util::byte_pair_encode(piece, &self.encoder)); } ret } - fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec, usize) { + pub fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>, max_tokens: Option) -> (Vec, usize, usize) { + let max_tokens = max_tokens.unwrap_or(usize::MAX); let special_regex = self._get_tl_special_regex(); let regex = self._get_tl_regex(); let mut ret = vec![]; @@ -260,11 +152,20 @@ impl CoreBPE { if let Some(token) = self.encoder.get(piece) { last_piece_token_len = 1; ret.push(*token); + + if ret.len() >= max_tokens { + return (ret, last_piece_token_len, start); + } continue; } - let tokens = byte_pair_encode(piece, &self.encoder); + let tokens = util::byte_pair_encode(piece, &self.encoder); last_piece_token_len = tokens.len(); - ret.extend(&tokens); + for token in tokens { + ret.push(token); + if ret.len() >= max_tokens { + return (ret, last_piece_token_len, start); + } + } } match next_special { @@ -273,8 +174,12 @@ impl CoreBPE { let piece = m.as_str(); let token = self.special_tokens_encoder[piece]; ret.push(token); + start = m.end(); last_piece_token_len = 0; + if ret.len() >= max_tokens { + return (ret, last_piece_token_len, start); + } } None => break, } @@ -282,7 +187,32 @@ impl CoreBPE { // last_piece_token_len is how many tokens came from the last regex split. 
This is used // for determining unstable tokens, since you can't merge across (stable) regex splits - (ret, last_piece_token_len) + (ret, last_piece_token_len, start) + } + + pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec { + match std::str::from_utf8(bytes) { + Ok(text) => self._encode_ordinary_native(text), + Err(e) => { + let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; + let (tokens, last_piece_token_len, _) = self._encode_native(text, &HashSet::new(), None); + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + if !tokens.is_empty() && last_piece_token_len > 0 { + // Lop off the tokens from the last piece and run BPE on the remaining bytes + // Somewhat niche, but this may not be correct if we'd have had a regex + // split between the valid UTF-8 and the invalid bytes, which is why this + // method is private + let mut unstable_bytes = + self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); + + tokens.truncate(tokens.len() - last_piece_token_len); + tokens.extend(util::byte_pair_encode(&unstable_bytes, &self.encoder)); + } + tokens + } + } } fn _increase_last_piece_token_len( @@ -324,12 +254,12 @@ impl CoreBPE { (tokens, last_piece_token_len) } - fn _encode_unstable_native( + pub fn _encode_unstable_native( &self, text: &str, allowed_special: &HashSet<&str>, ) -> (Vec, HashSet>) { - let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); + let (tokens, last_piece_token_len, _) = self._encode_native(text, allowed_special, None); if last_piece_token_len == 0 { // If last_piece_token_len is zero, the last token was a special token and we have // no unstable bytes @@ -392,7 +322,7 @@ impl CoreBPE { // would be a regex split before the UTF-8 truncation point. // Probably niche enough that no one will ever notice (after all, people didn't // notice all the big holes in the previous unstable token implementation) - Err(_) => byte_pair_encode(&possibility, &self.encoder), + Err(_) => util::byte_pair_encode(&possibility, &self.encoder), // Something like the following is intriguing but incorrect: // Err(e) => self._encode_ordinary_native(unsafe { // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) @@ -425,11 +355,11 @@ impl CoreBPE { if unstable_bytes.len() - last_decoded.1 > 0 && last_decoded.0.map_or(false, |c| c.is_whitespace()) { - let mut reencoded = byte_pair_encode( + let mut reencoded = util::byte_pair_encode( &unstable_bytes[..unstable_bytes.len() - last_decoded.1], &self.encoder, ); - reencoded.extend(byte_pair_encode( + reencoded.extend(util::byte_pair_encode( &unstable_bytes[unstable_bytes.len() - last_decoded.1..], &self.encoder, )); @@ -439,108 +369,8 @@ impl CoreBPE { (tokens, completions) } -} - -#[pymethods] -impl CoreBPE { - #[new] - fn new( - encoder: HashMap, usize>, - special_tokens_encoder: HashMap, - pattern: &str, - ) -> PyResult { - let regex = Regex::new(pattern) - .map_err(|e| PyErr::new::(e.to_string()))?; - - let special_regex = { - let _parts = special_tokens_encoder - .keys() - .map(|s| fancy_regex::escape(s)) - .collect::>(); - Regex::new(&_parts.join("|")) - .map_err(|e| PyErr::new::(e.to_string()))? 
- }; - - let decoder: HashMap> = - encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); - - assert!(encoder.len() == decoder.len()); - - let special_tokens_decoder: HashMap> = special_tokens_encoder - .iter() - .map(|(k, v)| (*v, k.as_bytes().to_vec())) - .collect(); - - // Clone because I don't know how to tell Rust I'm not going to change the map - let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); - sorted_token_bytes.sort(); - - Ok(CoreBPE { - encoder, - special_tokens_encoder, - decoder, - special_tokens_decoder, - regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), - special_regex_tls: (0..MAX_NUM_THREADS) - .map(|_| special_regex.clone()) - .collect(), - sorted_token_bytes, - }) - } - - // ==================== - // Encoding - // ==================== - - fn encode_ordinary(&self, py: Python, text: &str) -> Vec { - py.allow_threads(|| self._encode_ordinary_native(text)) - } - fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { - py.allow_threads(|| self._encode_native(text, &allowed_special).0) - } - - fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { - py.allow_threads(|| { - match std::str::from_utf8(bytes) { - Ok(text) => self._encode_ordinary_native(text), - Err(e) => { - let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; - let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); - let (mut tokens, last_piece_token_len) = - self._increase_last_piece_token_len(tokens, last_piece_token_len); - if !tokens.is_empty() && last_piece_token_len > 0 { - // Lop off the tokens from the last piece and run BPE on the remaining bytes - // Somewhat niche, but this may not be correct if we'd have had a regex - // split between the valid UTF-8 and the invalid bytes, which is why this - // method is private - let mut unstable_bytes = - self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); - unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); - - tokens.truncate(tokens.len() - last_piece_token_len); - tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); - } - tokens - } - } - }) - } - - fn encode_with_unstable( - &self, - py: Python, - text: &str, - allowed_special: HashSet<&str>, - ) -> Py { - let (tokens, completions) = - py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); - let py_completions = - PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); - (tokens, py_completions).into_py(py) - } - - fn encode_single_token(&self, piece: &[u8]) -> PyResult { + pub fn encode_single_token(&self, piece: &[u8]) -> Result> { if let Some(token) = self.encoder.get(piece).copied() { return Ok(token); } @@ -549,66 +379,80 @@ impl CoreBPE { return Ok(token); } } - Err(PyErr::new::(piece.to_owned())) + Err(piece.to_owned()) } fn encode_single_piece(&self, piece: &[u8]) -> Vec { if let Some(token) = self.encoder.get(piece) { return vec![*token]; } - byte_pair_encode(piece, &self.encoder) + util::byte_pair_encode(piece, &self.encoder) } // ==================== // Decoding // ==================== - fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { - let bytes = py.allow_threads(|| self._decode_native(&tokens)); - PyBytes::new(py, &bytes).into() - } - - fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { + pub fn decode_single_token_bytes(&self, token: usize) -> Result<&[u8], String> { if let Some(bytes) = self.decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + 
return Ok(bytes); } if let Some(bytes) = self.special_tokens_decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + return Ok(bytes); } - Err(PyErr::new::(token.to_string())) + Err(token.to_string()) } // ==================== // Miscellaneous // ==================== - fn token_byte_values(&self, py: Python) -> Vec> { - self.sorted_token_bytes - .iter() - .map(|x| PyBytes::new(py, x).into()) - .collect() + pub fn token_byte_values(&self) -> &Vec> { + &self.sorted_token_bytes } -} -#[pymodule] -fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - Ok(()) -} + pub fn new( + encoder: HashMap, usize>, + special_tokens_encoder: HashMap, + pattern: &str, + ) -> Result { + let regex = Regex::new(pattern)?; + // .map_err(|e| PyErr::new::(e.to_string()))?; -#[cfg(test)] -mod tests { - use rustc_hash::FxHashMap as HashMap; + let special_regex = { + let _parts = special_tokens_encoder + .keys() + .map(|s| fancy_regex::escape(s)) + .collect::>(); + Regex::new(&_parts.join("|"))? - use crate::byte_pair_split; + // .map_err(|e| PyErr::new::(e.to_string()))? + }; - #[test] - fn very_simple_test() { - let mut ranks = HashMap::default(); - ranks.insert(b"ab".to_vec(), 1); - ranks.insert(b"cd".to_vec(), 2); + let decoder: HashMap> = + encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + + assert!(encoder.len() == decoder.len()); - let res = byte_pair_split(b"abcd", &ranks); - assert_eq!(res, vec![b"ab", b"cd"]); + let special_tokens_decoder: HashMap> = special_tokens_encoder + .iter() + .map(|(k, v)| (*v, k.as_bytes().to_vec())) + .collect(); + + // Clone because I don't know how to tell Rust I'm not going to change the map + let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); + sorted_token_bytes.sort(); + + Ok(CoreBPENative { + encoder, + special_tokens_encoder, + decoder, + special_tokens_decoder, + regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), + special_regex_tls: (0..MAX_NUM_THREADS) + .map(|_| special_regex.clone()) + .collect(), + sorted_token_bytes, + }) } -} +} \ No newline at end of file diff --git a/core/src/load.rs b/core/src/load.rs new file mode 100644 index 00000000..975f5fcd --- /dev/null +++ b/core/src/load.rs @@ -0,0 +1,168 @@ + +use rustc_hash::FxHashMap as HashMap; +use std::{env, path::PathBuf}; +use sha1::{Sha1, Digest}; +use std::error::Error; +use json; + +type Result = std::result::Result>; + +fn read_file(blobpath: &str) -> Result> { + // TODO: support blobs? 
+ + if !(blobpath.starts_with("http") || blobpath.starts_with("https")) { + return Ok(std::fs::read(blobpath)?); + } + + Ok(reqwest::blocking::get(blobpath)?.bytes()?.to_vec()) +} + +fn get_tiktoken_cache_dir() -> PathBuf { + match env::var_os("TIKTOKEN_CACHE_DIR") { + Some(v) => PathBuf::from(v), + None => { + match env::var_os("DATA_GYM_CACHE_DIR") { + Some(v) => PathBuf::from(v), + None => { + let mut temp_dir = env::temp_dir(); + temp_dir.push("data-gym-cache"); + + temp_dir + } + } + } + } +} + +fn sha1_as_hex(s: &str) -> String { + let mut hasher = Sha1::new(); + hasher.update(s.as_bytes()); + let result = hasher.finalize(); + + format!("{:x}", result) +} + +fn read_file_cached(blobpath: &str) -> Result> { + let mut cache_path = get_tiktoken_cache_dir(); + + if !cache_path.exists() { + std::fs::create_dir_all(&cache_path)?; + } + + cache_path.push(sha1_as_hex(blobpath)); + + println!("cache_path: {:?}", cache_path); + + if cache_path.exists() { + let catch_path_str = cache_path.into_os_string().into_string() + .or(Err( { + // let cache_path_lossy_str = cache_path.to_string_lossy().to_string(); + // format!("Unable to convert path {cache_path_lossy_str}") + format!("Unable to convert path") + }))?; + return read_file(&catch_path_str); + } + + let content = read_file(blobpath)?; + + std::fs::write(cache_path, &content)?; + + Ok(content) +} + +fn is_printable(u: u8) -> bool { + // printable ascii characters according to python + !(u <= 31 || (u >= 127 && u <= 160) || u == 173) +} + +pub fn data_gym_to_mergeable_bpe_ranks(vocab_bpe_file: &str, encoder_json_file: &str) -> Result, usize>> { + let mut rank_to_intbyte = (0..=255) + .filter(|x| is_printable(*x) && (*x as char) != ' ') + .collect::>(); + + let mut data_gym_byte_to_byte = rank_to_intbyte + .iter() + .map(|&x| (x as u32, x)) + .collect::>(); + + let mut n = 0; + for b in 0..=255 { + if !rank_to_intbyte.contains(&b) { + rank_to_intbyte.push(b); + data_gym_byte_to_byte.insert(256 + n, b); + n += 1; + } + } + assert!(rank_to_intbyte.len() == 256); + + // vocab_bpe contains the merges along with associated ranks + let cached_vocab = read_file_cached(vocab_bpe_file)?; + let vocab_bpe_contents = std::str::from_utf8(&cached_vocab)? 
+ .split("\n").collect::>(); + + let bpe_merges = match vocab_bpe_contents[1..(vocab_bpe_contents.len() - 1)] + .iter() + .map(|&s| s.split_whitespace()) + .map(|mut sp| match (sp.next(), sp.next()) { + (Some(a), Some(b)) => Some((a, b)), + _ => None, + }) + .collect::>>() + { + Some(v) => v, + None => return Err("Unable to parse vocab_bpe file".into()), + }; + + let decode_data_gym = + |value: &str| value.chars().map(|c| { + data_gym_byte_to_byte[&(c as u32)] + } ).collect::>(); + + // # add the single byte tokens + let mut bpe_ranks = + rank_to_intbyte + .iter() + .enumerate() + .map(|(i, b)| (vec![*b], i)) + .collect::, usize>>(); + + // add the merged tokens + let mut n = bpe_ranks.len(); + for (first, second) in bpe_merges { + bpe_ranks.insert([decode_data_gym(first), decode_data_gym(second)].concat(), n); + n += 1; + } + + // check that the encoder file matches the merges file + // this sanity check is important since tiktoken assumes that ranks are ordered the same + // as merge priority + let cached_encoder = read_file_cached(encoder_json_file)?; + let encoder_json = json::parse(&std::str::from_utf8(&cached_encoder)?)?; + + let mut encoder_json_loaded = encoder_json.entries() + .map(|(k, v)| (decode_data_gym(k), v.as_usize().unwrap())) + .collect::, usize>>(); + + // drop these two special tokens if present, since they're not mergeable bpe tokens + encoder_json_loaded.remove(&decode_data_gym("<|endoftext|>")); + encoder_json_loaded.remove(&decode_data_gym("<|startoftext|>")); + + assert!(bpe_ranks == encoder_json_loaded); + + Ok(bpe_ranks) +} + +pub fn load_tiktoken_bpe(tiktoken_bpe_file: &str) -> Result, usize>> { + use base64::{engine::general_purpose, Engine as _}; + + let content = read_file_cached(tiktoken_bpe_file)?; + + Ok(std::str::from_utf8(&content)? + .lines() + .filter(|s| s.len() > 0) + .map(|s| s.split_whitespace()) + .map(|mut sp| (sp.next().unwrap(), sp.next().unwrap())) + .map(|(first, second)| (general_purpose::STANDARD.decode(&first).unwrap(), second.parse::().unwrap())) + .collect::, usize>>()) +} + diff --git a/core/src/openai_public.rs b/core/src/openai_public.rs new file mode 100644 index 00000000..24e0ab99 --- /dev/null +++ b/core/src/openai_public.rs @@ -0,0 +1,125 @@ + +use rustc_hash::FxHashMap as HashMap; +use std::error::Error; +use std::sync::RwLock; +use json; + +#[path = "load.rs"] +mod load; + +type Result = std::result::Result>; + +lazy_static! 
{ + pub static ref REGISTRY: HashMap = { + json::parse(include_str!("../../tiktoken/registry.json")) + .expect("Failed to parse internal JSON") + .entries() + .map(|(key, value)| { + let loading_strategy = if value.has_key("data_gym_to_mergeable_bpe_ranks") { + EncoderLoadingStrategy::DataGym( + DataGymDef { + vocab_bpe_file: value["data_gym_to_mergeable_bpe_ranks"]["vocab_bpe_file"].as_str().expect("error").into(), + encoder_json_file: value["data_gym_to_mergeable_bpe_ranks"]["encoder_json_file"].as_str().expect("error").into() + }) + } + else if value.has_key("load_tiktoken_bpe") { + EncoderLoadingStrategy::BPE(value["load_tiktoken_bpe"].as_str().expect("fail").into()) + } + else { + panic!("Invalid encoding"); + }; + + EncodingLazy::new( + key.into(), + value["explicit_n_vocab"].as_usize(), + value["pat_str"].as_str().expect("foo").into(), + value["special_tokens"].entries() + .map(|(key, value)| (key.into(), value.as_usize().expect("foo"))) + .collect::>(), + loading_strategy + ) + }) + + .map(|enc| (enc.name.clone(), enc)) + .collect::>() + }; + + pub static ref MODEL_TO_ENCODING: HashMap = + json::parse(include_str!("../../tiktoken/model_to_encoding.json")) + .expect("Failed to parse internal JSON") + .entries() + .map(|(k, v)| (k.into(), v.as_str().expect("foo").into())) + .collect::>(); +} + +#[derive(Clone, PartialEq, Eq, Hash)] +struct DataGymDef { + vocab_bpe_file: String, + encoder_json_file: String, +} + +#[derive(Clone, PartialEq, Eq, Hash)] +enum EncoderLoadingStrategy { + BPE(String), + DataGym(DataGymDef), +} + +pub struct EncodingLazy { + name: String, + explicit_n_vocab: Option, + pub pat_str: String, + pub special_tokens: HashMap, + mergeable_ranks: RwLock, usize>>>, + loading_strategy: EncoderLoadingStrategy, +} + +fn load_bpe(path: &str) -> Result, usize>> { + load::load_tiktoken_bpe(path) +} + +fn load_data_gym(def: &DataGymDef) -> Result, usize>> { + load::data_gym_to_mergeable_bpe_ranks(&def.vocab_bpe_file, &def.encoder_json_file) +} + +// #[memoize] +fn load_mergeable_ranks(loading_strategy: &EncoderLoadingStrategy) -> Result, usize>> +{ + match loading_strategy { + EncoderLoadingStrategy::BPE(path) => load_bpe(&path), + EncoderLoadingStrategy::DataGym(def) => load_data_gym(&def), + } +} + +impl EncodingLazy { + fn new(name: String, + explicit_n_vocab: Option, + pat_str: String, + special_tokens: HashMap, + loading_strategy: EncoderLoadingStrategy) -> Self { + EncodingLazy { + name, + explicit_n_vocab, + pat_str, + special_tokens, + mergeable_ranks: RwLock::new(None), + loading_strategy + } + } + + pub fn get(&self) -> Result, usize>> { + { + let read = self.mergeable_ranks.read().unwrap(); + if read.is_some() { + return Ok(read.as_ref().unwrap().clone()); + } + } + + let mut write = self.mergeable_ranks.write().unwrap(); + *write = Some(load_mergeable_ranks(&self.loading_strategy)?); + + Ok(write.as_ref().unwrap().clone()) + } +} + + + diff --git a/core/src/util.rs b/core/src/util.rs new file mode 100644 index 00000000..b9605a18 --- /dev/null +++ b/core/src/util.rs @@ -0,0 +1,136 @@ +use rustc_hash::FxHashMap as HashMap; + +fn _byte_pair_merge( + piece: &[u8], + ranks: &HashMap, usize>, + f: impl Fn(std::ops::Range) -> T, +) -> Vec { + // This is a vector of (start, rank). + // The rank is of the byte pair starting at position start. + // The rank of the last item in the vector is not a valid value. 
+ let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect(); + + // NOTE: using a macro here because a closure fails to get inlined + // according to optimization remarks. + // A closure also cannot capture a reference to `piece` without + // the borrow checker complaining about the mutable borrows during + // the assignments later in this code. + macro_rules! get_rank { + ($start_idx:expr, $skip:expr) => {{ + let start_idx: usize = $start_idx; + let skip: usize = $skip; + if (start_idx + skip + 2) < parts.len() { + ranks + .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) + .map(|r| *r) + } else { + None + } + }}; + ($idx:expr) => {{ + get_rank!($idx, 0) + }}; + } + + // We look up the ranks once in the beggining and iteratively update + // them during each merge, which reduces the number of rank lookups. + for i in 0..parts.len() - 2 { + match get_rank!(i) { + Some(rank) => { + // usize::MAX is a sentinel value and cannot be a valid rank + debug_assert!(rank != usize::MAX); + parts[i].1 = rank; + } + None => { + continue; + } + }; + } + + // If you have n parts and m merges, this does O(mn) work. + // We could do something with a heap and do O(m log n) work. + // It is important to consider that n is often small (<100), and as such + // the cache-locality benefits outweigh the algorithmic complexity downsides + // of the `parts` vector data structure above. + + // Note that we hash bytes, not token pairs. As long as we train BPE the way we + // currently do, this is equivalent. An easy way to break this would be to decouple + // merge priority from token index or to prevent specific token merges. + loop { + if parts.len() == 1 { + break; + } + + // usize::MAX is a sentinel rank value allowing us to + // take the min more quickly + let mut min_rank: (usize, usize) = (usize::MAX, 0); + for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { + if rank < min_rank.0 { + min_rank = (rank, i); + } + } + + if min_rank.0 != usize::MAX { + let i = min_rank.1; + + // NOTE: We are about to remove parts[i + 1]. We do not do it + // yet because there are cache-locality benefits to updating + // parts[i] and parts[i-1] before removing, which could thrash + // the cache. Thus, we update the rank calculation by skipping over + // parts[i + 1], by invoking `get_rank!` with `skip = 1`. 
+ parts[i].1 = get_rank!(i, 1).unwrap_or(usize::MAX); + if i > 0 { + parts[i - 1].1 = get_rank!(i - 1, 1).unwrap_or(usize::MAX); + } + + parts.remove(i + 1); + } else { + break; + } + } + let mut out: Vec = Vec::with_capacity(parts.len() - 1); + for i in 0..parts.len() - 1 { + out.push(f(parts[i].0..parts[i + 1].0)); + } + out +} + +pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, usize>) -> Vec { + if piece.len() == 1 { + return vec![ranks[piece]]; + } + _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]]) +} + +pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, usize>) -> Vec<&'a [u8]> { + if piece.len() == 1 { + return vec![piece]; + } + _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end]) +} + +#[cfg(test)] +mod tests { + use rustc_hash::FxHashMap as HashMap; + + use crate::util::_byte_pair_merge; + pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, usize>) -> Vec<&'a [u8]> { + if piece.len() == 1 { + return vec![piece]; + } + _byte_pair_merge(piece, ranks) + .iter() + .map(|p| &piece[p.start..p.end]) + .collect() + } + + #[test] + fn very_simple_test() { + let mut ranks = HashMap::default(); + ranks.insert(b"ab".to_vec(), 1); + ranks.insert(b"cd".to_vec(), 2); + + let res = byte_pair_split(b"abcd", &ranks); + assert_eq!(res, vec![b"ab", b"cd"]); + } +} \ No newline at end of file diff --git a/java/pom.xml b/java/pom.xml new file mode 100644 index 00000000..61cbf01c --- /dev/null +++ b/java/pom.xml @@ -0,0 +1,99 @@ + + + + 4.0.0 + + com.openai + tiktoken + 1.0-SNAPSHOT + + tiktoken + https://github.com/openai/tiktoken + jar + + + UTF-8 + 1.7 + 1.7 + + + + + junit + junit + 4.11 + test + + + org.scijava + native-lib-loader + 2.4.0 + + + + + + + ${project.basedir}/../natives/ + ${project.build.directory}/classes/natives/ + + + + + + + maven-clean-plugin + 3.1.0 + + + + maven-resources-plugin + 3.0.2 + + + maven-compiler-plugin + 3.8.0 + + + maven-surefire-plugin + 2.22.1 + + + maven-jar-plugin + 3.0.2 + + + maven-install-plugin + 2.5.2 + + + maven-deploy-plugin + 2.8.2 + + + + maven-site-plugin + 3.7.1 + + + maven-project-info-reports-plugin + 3.0.0 + + + org.apache.maven.plugins + maven-failsafe-plugin + 2.22.1 + + + + integration-test + verify + + + + + + + + diff --git a/java/src/main/java/tiktoken/Encoding.java b/java/src/main/java/tiktoken/Encoding.java new file mode 100644 index 00000000..1773225d --- /dev/null +++ b/java/src/main/java/tiktoken/Encoding.java @@ -0,0 +1,34 @@ +package tiktoken; + +import org.scijava.nativelib.NativeLoader; +import java.io.IOException; + +public class Encoding implements AutoCloseable +{ + static { + try { + // load from JAR + NativeLoader.loadLibrary("_tiktoken_jni"); + } + catch(IOException e) { + throw new RuntimeException(e); + } + } + + // initialized by init + private long handle; + + private native void init(String modelName); + + private native void destroy(); + + public native long[] encode(String text, String[] allowedSpecialTokens, long maxTokenLength); + + public Encoding(String modelName) { + this.init(modelName); + } + + public void close() throws Exception { + destroy(); + } +} diff --git a/java/src/test/java/tiktoken/EncodingTestIT.java b/java/src/test/java/tiktoken/EncodingTestIT.java new file mode 100644 index 00000000..602a1ef9 --- /dev/null +++ b/java/src/test/java/tiktoken/EncodingTestIT.java @@ -0,0 +1,21 @@ +package tiktoken; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; + +// run test: mvn failsafe:integration-test +public class EncodingTestIT +{ + 
@Test + public void shouldAnswerWithTrue() throws Exception + { + Encoding encoding = new Encoding("text-davinci-001"); + + long[] a = encoding.encode("test", new String[0], 0); + + encoding.close(); + + assertArrayEquals(new long[] {9288}, a); + } +} diff --git a/jni/Cargo.toml b/jni/Cargo.toml new file mode 100644 index 00000000..7c6d4155 --- /dev/null +++ b/jni/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "tiktoken_jni" +version = "0.2.0" +edition = "2021" +rust-version = "1.57.0" + +[lib] +name = "_tiktoken_jni" +crate-type = ["cdylib"] + +[dependencies] +tiktoken_core = { path = "../core" } +rustc-hash = "1.1.0" +jni = "0.20.0" + +[build-dependencies] +json = "0.12.4" diff --git a/jni/build.rs b/jni/build.rs new file mode 100644 index 00000000..9c866413 --- /dev/null +++ b/jni/build.rs @@ -0,0 +1,7 @@ +use json; + +fn main() { + json::parse(include_str!("../tiktoken/registry.json")).expect("Failed to parse internal JSON"); + json::parse(include_str!("../tiktoken/model_to_encoding.json")).expect("Failed to parse internal JSON"); + println!("JSON Parsing validated"); +} diff --git a/jni/src/lib.rs b/jni/src/lib.rs new file mode 100644 index 00000000..6bd99d6d --- /dev/null +++ b/jni/src/lib.rs @@ -0,0 +1,114 @@ +use std::collections::HashSet; +use std::sync::MutexGuard; + +use _tiktoken_core::openai_public::EncodingLazy; +use jni::JNIEnv; +// These objects are what you should use as arguments to your native +// function. They carry extra lifetime information to prevent them escaping +// this context and getting used after being GC'd. +use jni::objects::{JObject, JString}; + +// This is just a pointer. We'll be returning it from our function. We +// can't return one of the objects with lifetime information because the +// lifetime checker won't let us. +use jni::sys::{jarray, jlong}; + +use _tiktoken_core::{self, CoreBPENative}; + +type Result = std::result::Result>; + +fn unwrap_or_throw(env: &JNIEnv, result: Result, default: T) -> T { + // Check if an exception is already thrown + if env.exception_check().expect("exception_check() failed") { + return default; + } + + match result { + Ok(tokenizer) => tokenizer, + Err(error) => { + let exception_class = env + .find_class("java/lang/Exception") + .expect("Unable to find exception class"); + env.throw_new(exception_class, format!("{}", error)) + .expect("Unable to throw exception"); + default + } + } +} + +#[no_mangle] +pub extern "system" fn Java_tiktoken_Encoding_init(env: JNIEnv, obj: JObject, model_name: JString) { + let result = || -> Result<()> { + // First, we have to get the string out of Java. Check out the `strings` + // module for more info on how this works. + let model_name: String = env + .get_string(model_name)? 
+ .into(); + + let encoding_name = _tiktoken_core::openai_public::MODEL_TO_ENCODING + .get(&model_name).ok_or("Unable to find model")?; + + let encoding = _tiktoken_core::openai_public::REGISTRY + .get(encoding_name).ok_or("Unable to find encoding")?; + + let bpe_native = CoreBPENative::new( + encoding.get()?, + encoding.special_tokens.clone(), + &encoding.pat_str, + )?; + + Ok(unsafe { + env.set_rust_field(obj, "handle", bpe_native)?; + }) + }(); + + unwrap_or_throw(&env, result, ()) +} + +#[no_mangle] +pub extern "system" fn Java_tiktoken_Encoding_destroy(env: JNIEnv, obj: JObject) { + unsafe { + let _: CoreBPENative = env.take_rust_field(obj, "handle").expect("Unable to get handle during destruction"); + } +} + +#[no_mangle] +pub extern "system" fn Java_tiktoken_Encoding_encode( + env: JNIEnv, + obj: JObject, + text: JString, + allowed_special_tokens: jarray, + max_token_length: jlong, +) -> jarray { + let result = || -> Result { + let encoding: MutexGuard = unsafe { env.get_rust_field(obj, "handle")? }; + + let enc = encoding; + let input: String = env + .get_string(text)? + .into(); + + let len = env.get_array_length(allowed_special_tokens)?; + let mut strings: Vec = Vec::with_capacity(len as usize); + for i in 0..len { + let element: JObject = env + .get_object_array_element(allowed_special_tokens, i)?; + let current: String = env.get_string(element.into())?.into(); + strings.push(current); + } + + let v2: HashSet<&str> = strings.iter().map(|s| &**s).collect(); + + let (tokens, _, _) = enc._encode_native(&input, &v2, Some(max_token_length as usize)); + + let output = env + .new_long_array(tokens.len().try_into()?)?; + + let array_of_u64 = tokens.iter().map(|x| *x as i64).collect::>(); + env.set_long_array_region(output, 0, array_of_u64.as_slice())?; + + Ok(output) + }(); + + unwrap_or_throw(&env, result, JObject::null().into_raw()) +} diff --git a/python/Cargo.toml b/python/Cargo.toml new file mode 100644 index 00000000..7febd473 --- /dev/null +++ b/python/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "tiktoken" +version = "0.2.0" +edition = "2021" +rust-version = "1.57.0" + +[lib] +name = "_tiktoken" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.17.3", features = ["extension-module"] } +tiktoken_core = { path = "../core" } +rustc-hash = "1.1.0" diff --git a/python/src/lib.rs b/python/src/lib.rs new file mode 100644 index 00000000..599105e2 --- /dev/null +++ b/python/src/lib.rs @@ -0,0 +1,97 @@ +#![allow(clippy::borrow_deref_ref)] + +use std::collections::HashSet; + +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyList, PyTuple}; +use pyo3::PyResult; +use rustc_hash::FxHashMap as HashMap; + +use _tiktoken_core::CoreBPENative; + +#[pyclass] +struct CoreBPE { + native: CoreBPENative, +} + +#[pymethods] +impl CoreBPE { + #[new] + fn new( + encoder: HashMap, usize>, + special_tokens_encoder: HashMap, + pattern: &str, + ) -> PyResult { + let native = CoreBPENative::new(encoder, special_tokens_encoder, pattern) + .map_err(|e| PyErr::new::(e.to_string()))?; + Ok(CoreBPE { native }) + } + + // ==================== + // Encoding + // ==================== + + fn encode_ordinary(&self, py: Python, text: &str) -> Vec { + py.allow_threads(|| self.native._encode_ordinary_native(text)) + } + + fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { + py.allow_threads(|| self.native._encode_native(text, &allowed_special, None).0) + } + + fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { + py.allow_threads(|| { + 
self.native._encode_bytes(bytes) + }) + } + + fn encode_with_unstable( + &self, + py: Python, + text: &str, + allowed_special: HashSet<&str>, + ) -> Py { + let (tokens, completions) = + py.allow_threads(|| self.native._encode_unstable_native(text, &allowed_special)); + let py_completions = + PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); + (tokens, py_completions).into_py(py) + } + + fn encode_single_token(&self, piece: &[u8]) -> PyResult { + self.native.encode_single_token(piece).map_err(|e| PyErr::new::(e)) + } + + // ==================== + // Decoding + // ==================== + + fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { + let bytes = py.allow_threads(|| self.native._decode_native(&tokens)); + PyBytes::new(py, &bytes).into() + } + + fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { + self.native.decode_single_token_bytes(token).map(|bytes| PyBytes::new(py, &bytes).into()) + .map_err(|e| PyErr::new::(e)) + } + + // ==================== + // Miscellaneous + // ==================== + + fn token_byte_values(&self, py: Python) -> Vec> { + self.native.token_byte_values() + .iter() + .map(|x| PyBytes::new(py, x).into()) + .collect() + } +} + + +#[pymodule] +fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} diff --git a/setup.py b/setup.py index a22e8e5d..246487b0 100644 --- a/setup.py +++ b/setup.py @@ -7,12 +7,14 @@ RustExtension( "tiktoken._tiktoken", binding=Binding.PyO3, + path="python/Cargo.toml", # Between our use of editable installs and wanting to use Rust for performance sensitive # code, it makes sense to just always use --release debug=False, ) ], - package_data={"tiktoken": ["py.typed"]}, - packages=["tiktoken", "tiktoken_ext"], + include_package_data=True, + package_data={ "tiktoken": ["py.typed", "registry.json", "model_to_encoding.json"] }, + packages=["tiktoken"], zip_safe=False, ) diff --git a/tiktoken/load.py b/tiktoken/load.py index c5881068..5537ecf4 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -73,6 +73,7 @@ def decode_data_gym(value: str) -> bytes: # add the single byte tokens bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)} + # add the merged tokens n = len(bpe_ranks) for first, second in bpe_merges: diff --git a/tiktoken/model.py b/tiktoken/model.py index 66e9e046..b3d3ba59 100644 --- a/tiktoken/model.py +++ b/tiktoken/model.py @@ -2,47 +2,16 @@ from .core import Encoding from .registry import get_encoding +import json -# TODO: this will likely be replaced by an API endpoint -MODEL_TO_ENCODING: dict[str, str] = { - # text - "text-davinci-003": "p50k_base", - "text-davinci-002": "p50k_base", - "text-davinci-001": "r50k_base", - "text-curie-001": "r50k_base", - "text-babbage-001": "r50k_base", - "text-ada-001": "r50k_base", - "davinci": "r50k_base", - "curie": "r50k_base", - "babbage": "r50k_base", - "ada": "r50k_base", - # code - "code-davinci-002": "p50k_base", - "code-davinci-001": "p50k_base", - "code-cushman-002": "p50k_base", - "code-cushman-001": "p50k_base", - "davinci-codex": "p50k_base", - "cushman-codex": "p50k_base", - # edit - "text-davinci-edit-001": "p50k_edit", - "code-davinci-edit-001": "p50k_edit", - # embeddings - "text-embedding-ada-002": "cl100k_base", - # old embeddings - "text-similarity-davinci-001": "r50k_base", - "text-similarity-curie-001": "r50k_base", - "text-similarity-babbage-001": "r50k_base", - "text-similarity-ada-001": "r50k_base", - "text-search-davinci-doc-001": "r50k_base", - 
"text-search-curie-doc-001": "r50k_base", - "text-search-babbage-doc-001": "r50k_base", - "text-search-ada-doc-001": "r50k_base", - "code-search-babbage-code-001": "r50k_base", - "code-search-ada-code-001": "r50k_base", - # open source - "gpt2": "gpt2", -} +try: + import importlib.resources as pkg_resources +except ImportError: + # Try backported to PY<37 `importlib_resources`. + import importlib_resources as pkg_resources +# TODO: this will likely be replaced by an API endpoint +MODEL_TO_ENCODING: dict[str, str] = json.loads(pkg_resources.read_text("tiktoken", "model_to_encoding.json")) def encoding_for_model(model_name: str) -> Encoding: try: diff --git a/tiktoken/model_to_encoding.json b/tiktoken/model_to_encoding.json new file mode 100644 index 00000000..2b82312a --- /dev/null +++ b/tiktoken/model_to_encoding.json @@ -0,0 +1,33 @@ +{ + "text-davinci-003": "p50k_base", + "text-davinci-002": "p50k_base", + "text-davinci-001": "r50k_base", + "text-curie-001": "r50k_base", + "text-babbage-001": "r50k_base", + "text-ada-001": "r50k_base", + "davinci": "r50k_base", + "curie": "r50k_base", + "babbage": "r50k_base", + "ada": "r50k_base", + "code-davinci-002": "p50k_base", + "code-davinci-001": "p50k_base", + "code-cushman-002": "p50k_base", + "code-cushman-001": "p50k_base", + "davinci-codex": "p50k_base", + "cushman-codex": "p50k_base", + "text-davinci-edit-001": "p50k_edit", + "code-davinci-edit-001": "p50k_edit", + "text-embedding-ada-002": "cl100k_base", + "text-similarity-davinci-001": "r50k_base", + "text-similarity-curie-001": "r50k_base", + "text-similarity-babbage-001": "r50k_base", + "text-similarity-ada-001": "r50k_base", + "text-search-davinci-doc-001": "r50k_base", + "text-search-curie-doc-001": "r50k_base", + "text-search-babbage-doc-001": "r50k_base", + "text-search-ada-doc-001": "r50k_base", + "code-search-babbage-code-001": "r50k_base", + "code-search-ada-code-001": "r50k_base", + "gpt2": "gpt2", + "gpt-3.5-turbo": "cl100k_base" +} \ No newline at end of file diff --git a/tiktoken/registry.json b/tiktoken/registry.json new file mode 100644 index 00000000..aa3ee530 --- /dev/null +++ b/tiktoken/registry.json @@ -0,0 +1,50 @@ +{ + "gpt2": { + "data_gym_to_mergeable_bpe_ranks": { + "vocab_bpe_file": "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe", + "encoder_json_file": "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json" + }, + "explicit_n_vocab": 50257, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + "special_tokens": { + "<|endoftext|>": 50256 + } + }, + "r50k_base": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken", + "explicit_n_vocab": 50257, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + "special_tokens": { + "<|endoftext|>": 50256 + } + }, + "p50k_base": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken", + "explicit_n_vocab": 50281, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + "special_tokens": { + "<|endoftext|>": 50256 + } + }, + "p50k_edit": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken", + "special_tokens": { + "<|endoftext|>": 50256, + "<|fim_prefix|>": 50281, + "<|fim_middle|>": 50282, + "<|fim_suffix|>": 50283 + }, + "pat_str": "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| 
?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+" + }, + "cl100k_base": { + "load_tiktoken_bpe": "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken", + "special_tokens": { + "<|endoftext|>": 100257, + "<|fim_prefix|>": 100258, + "<|fim_middle|>": 100259, + "<|fim_suffix|>": 100260, + "<|endofprompt|>": 100276 + }, + "pat_str": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + } +} \ No newline at end of file diff --git a/tiktoken/registry.py b/tiktoken/registry.py index 52d8ec2d..0a55d27e 100644 --- a/tiktoken/registry.py +++ b/tiktoken/registry.py @@ -3,46 +3,32 @@ import importlib import pkgutil import threading +import json from typing import Any, Callable, Optional -import tiktoken_ext - from tiktoken.core import Encoding +from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe _lock = threading.RLock() ENCODINGS: dict[str, Encoding] = {} -ENCODING_CONSTRUCTORS: Optional[dict[str, Callable[[], dict[str, Any]]]] = None - +ENCODING_DEFS: dict[str, Any] = None -def _find_constructors() -> None: - global ENCODING_CONSTRUCTORS - with _lock: - if ENCODING_CONSTRUCTORS is not None: - return - ENCODING_CONSTRUCTORS = {} +def _load_encoding_defs(): + global ENCODING_DEFS + if not ENCODING_DEFS is None: + return ENCODING_DEFS - # tiktoken_ext is a namespace package - # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes - # - we use namespace package pattern so `pkgutil.iter_modules` is fast - # - it's a separate top-level package because namespace subpackages of non-namespace - # packages don't quite do what you want with editable installs - plugin_mods = pkgutil.iter_modules(tiktoken_ext.__path__, tiktoken_ext.__name__ + ".") + try: + import importlib.resources as pkg_resources + except ImportError: + # Try backported to PY<37 `importlib_resources`. 
+ import importlib_resources as pkg_resources - for _, mod_name, _ in plugin_mods: - mod = importlib.import_module(mod_name) - try: - constructors = mod.ENCODING_CONSTRUCTORS - except AttributeError as e: - raise ValueError( - f"tiktoken plugin {mod_name} does not define ENCODING_CONSTRUCTORS" - ) from e - for enc_name, constructor in constructors.items(): - if enc_name in ENCODING_CONSTRUCTORS: - raise ValueError( - f"Duplicate encoding name {enc_name} in tiktoken plugin {mod_name}" - ) - ENCODING_CONSTRUCTORS[enc_name] = constructor + # read registry.json + # note: was trying to place it into /data/registry.json but python packaging is always unhappy + ENCODING_DEFS = json.loads(pkg_resources.read_text("tiktoken", "registry.json")) + return ENCODING_DEFS def get_encoding(encoding_name: str) -> Encoding: if encoding_name in ENCODINGS: @@ -52,22 +38,26 @@ def get_encoding(encoding_name: str) -> Encoding: if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] - if ENCODING_CONSTRUCTORS is None: - _find_constructors() - assert ENCODING_CONSTRUCTORS is not None - - if encoding_name not in ENCODING_CONSTRUCTORS: + _load_encoding_defs() + if encoding_name not in ENCODING_DEFS: raise ValueError(f"Unknown encoding {encoding_name}") - constructor = ENCODING_CONSTRUCTORS[encoding_name] - enc = Encoding(**constructor()) + encoding_def = dict(ENCODING_DEFS[encoding_name]) + encoding_def["name"] = encoding_name + + if "load_tiktoken_bpe" in encoding_def: + encoding_def["mergeable_ranks"] = load_tiktoken_bpe(encoding_def["load_tiktoken_bpe"]) + del encoding_def["load_tiktoken_bpe"] + elif "data_gym_to_mergeable_bpe_ranks" in encoding_def: + encoding_def["mergeable_ranks"] = data_gym_to_mergeable_bpe_ranks(**encoding_def["data_gym_to_mergeable_bpe_ranks"]) + del encoding_def["data_gym_to_mergeable_bpe_ranks"] + else: + raise ValueError(f"Unknown loader {encoding_name}") + enc = Encoding(**encoding_def) ENCODINGS[encoding_name] = enc return enc def list_encoding_names() -> list[str]: with _lock: - if ENCODING_CONSTRUCTORS is None: - _find_constructors() - assert ENCODING_CONSTRUCTORS is not None - return list(ENCODING_CONSTRUCTORS) + return list(_load_encoding_defs().keys()) diff --git a/tiktoken_ext/openai_public.py b/tiktoken_ext/openai_public.py deleted file mode 100644 index 522d58fb..00000000 --- a/tiktoken_ext/openai_public.py +++ /dev/null @@ -1,88 +0,0 @@ -from tiktoken.load import data_gym_to_mergeable_bpe_ranks, load_tiktoken_bpe - -ENDOFTEXT = "<|endoftext|>" -FIM_PREFIX = "<|fim_prefix|>" -FIM_MIDDLE = "<|fim_middle|>" -FIM_SUFFIX = "<|fim_suffix|>" -ENDOFPROMPT = "<|endofprompt|>" - - -def gpt2(): - mergeable_ranks = data_gym_to_mergeable_bpe_ranks( - vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe", - encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json", - ) - return { - "name": "gpt2", - "explicit_n_vocab": 50257, - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": {"<|endoftext|>": 50256}, - } - - -def r50k_base(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken" - ) - return { - "name": "r50k_base", - "explicit_n_vocab": 50257, - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": {ENDOFTEXT: 50256}, - } - - -def p50k_base(): - 
mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken" - ) - return { - "name": "p50k_base", - "explicit_n_vocab": 50281, - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": {ENDOFTEXT: 50256}, - } - - -def p50k_edit(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken" - ) - special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283} - return { - "name": "p50k_edit", - "pat_str": r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": special_tokens, - } - - -def cl100k_base(): - mergeable_ranks = load_tiktoken_bpe( - "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken" - ) - special_tokens = { - ENDOFTEXT: 100257, - FIM_PREFIX: 100258, - FIM_MIDDLE: 100259, - FIM_SUFFIX: 100260, - ENDOFPROMPT: 100276, - } - return { - "name": "cl100k_base", - "pat_str": r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""", - "mergeable_ranks": mergeable_ranks, - "special_tokens": special_tokens, - } - - -ENCODING_CONSTRUCTORS = { - "gpt2": gpt2, - "r50k_base": r50k_base, - "p50k_base": p50k_base, - "cl100k_base": cl100k_base, - "p50k_edit": p50k_edit, -}