surprisal.utils
utility functions supporting model and surprisal classes
1""" 2utility functions supporting model and surprisal classes 3""" 4 5import tokenizers 6from transformers import tokenization_utils_base 7 8 9def hf_pick_matching_token_ixs( 10 encoding: "tokenizers.Encoding", span_of_interest: slice, span_type: str 11) -> slice: 12 """Picks token indices in a tokenized encoded sequence that best correspond to 13 a substring of interest in the original sequence, given by a char span (slice) 14 Args: 15 encoding (transformers.tokenization_utils_base.BatchEncoding): the output of a 16 `tokenizer(text)` call on a single text instance (not a batch, i.e. 17 `tokenizer([text])`). 18 span_of_interest (slice): a `slice` object denoting the character indices in the 19 original `text` string we want to extract the corresponding tokens for 20 span_type (str): either `char` or `word`, denoting what type of span we are interested 21 in obtaining. this argument has no default to ensure the user is aware of what 22 kind of span they are getting from this function 23 Returns: 24 slice: the start and stop indices of **tokens** within an encoded sequence that 25 best match the `span_of_interest` 26 """ 27 span_of_interest = slice( 28 span_of_interest.start or 0, 29 span_of_interest.stop 30 or ( 31 len(encoding.ids) 32 if span_type == "word" 33 else encoding.token_to_chars((len(encoding.ids) - 1)).end 34 ), 35 span_of_interest.step, 36 ) 37 38 start_token = 0 39 end_token = len(encoding.ids) 40 for i, _ in enumerate(encoding.ids): 41 span = encoding.token_to_chars(i) 42 word = encoding.token_to_word(i) 43 # batchencoding 0 gives access to the encoded string 44 45 if span is None or word is None: 46 # for [CLS], no span is returned 47 # log(f'No span returned for token at {i}: "{batchencoding.tokens()[i]}"', 48 # type="WARN", cmap="WARN", verbosity_check=True) 49 continue 50 span = tokenization_utils_base.CharSpan(*span) 51 52 if span_type == "char": 53 if span.start <= span_of_interest.start: 54 start_token = i 55 if span.end >= 
span_of_interest.stop: 56 end_token = i + 1 57 break 58 elif span_type == "word": 59 if word < span_of_interest.start: 60 start_token = i + 1 61 # watch out for the semantics of the "stop" 62 if word == span_of_interest.stop: 63 end_token = i 64 elif word > span_of_interest.stop: 65 break 66 67 assert end_token - start_token <= len( 68 encoding.ids 69 ), "Extracted span is larger than original span" 70 71 return slice(start_token, end_token) 72 73 74openai_models_list = [ 75 "davinci-instruct-beta", 76 "babbage", 77 "text-similarity-ada-001", 78 "babbage-code-search-code", 79 "code-davinci-edit-001", 80 "ada", 81 "ada-similarity", 82 "babbage-search-query", 83 "text-search-curie-query-001", 84 "babbage-search-document", 85 "davinci-search-document", 86 "text-curie-001", 87 "text-similarity-babbage-001", 88 "text-similarity-curie-001", 89 "code-search-ada-text-001", 90 "text-search-ada-doc-001", 91 "audio-transcribe-001", 92 "text-search-curie-doc-001", 93 "curie-similarity", 94 "ada-search-document", 95 "text-davinci-insert-001", 96 "text-search-davinci-doc-001", 97 "ada-search-query", 98 "text-search-ada-query-001", 99 "text-davinci-001", 100 "curie", 101 "curie-instruct-beta", 102 "babbage-similarity", 103 "ada-code-search-text", 104 "davinci-similarity", 105 "text-search-davinci-query-001", 106 "babbage-code-search-text", 107 "code-search-babbage-code-001", 108 "text-davinci-002", 109 "text-davinci-003", 110 "text-ada-001", 111 "davinci-search-query", 112 "ada-code-search-code", 113 "curie-search-document", 114 "text-similarity-davinci-001", 115 "text-davinci-insert-002", 116 "code-search-babbage-text-001", 117 "text-davinci-edit-001", 118 "text-search-babbage-query-001", 119 "davinci", 120 "text-search-babbage-doc-001", 121 "curie-search-query", 122 "text-babbage-001", 123 "code-search-ada-code-001", 124 "cushman:2020-05-03", 125 "ada:2020-05-03", 126 "babbage:2020-05-03", 127 "curie:2020-05-03", 128 "davinci:2020-05-03", 129 "if-davinci-v2", 130 "if-curie-v2", 
131 "if-davinci:3.0.0", 132 "davinci-if:3.0.0", 133 "davinci-instruct-beta:2.0.0", 134 "text-ada:001", 135 "text-davinci:001", 136 "text-curie:001", 137 "text-babbage:001", 138]
def hf_pick_matching_token_ixs(
    encoding: "tokenizers.Encoding", span_of_interest: slice, span_type: str
) -> slice:
    """Picks token indices in a tokenized encoded sequence that best correspond to
    a substring of interest in the original sequence, given by a char span (slice)

    Args:
        encoding (transformers.tokenization_utils_base.BatchEncoding): the output of a
            `tokenizer(text)` call on a single text instance (not a batch, i.e.
            `tokenizer([text])`). Only `.ids`, `.token_to_chars(i)` and
            `.token_to_word(i)` are accessed.
        span_of_interest (slice): a `slice` object denoting the character indices in the
            original `text` string we want to extract the corresponding tokens for.
            An open-ended start defaults to 0; an open-ended stop defaults to the
            full token count (word spans) or the char offset just past the last
            non-special token (char spans).
        span_type (str): either `char` or `word`, denoting what type of span we are interested
            in obtaining. this argument has no default to ensure the user is aware of what
            kind of span they are getting from this function

    Returns:
        slice: the start and stop indices of **tokens** within an encoded sequence that
            best match the `span_of_interest`
    """
    # Resolve an open-ended stop lazily (only when stop is falsy, matching the
    # `x or default` idiom). For char spans, `token_to_chars` returns None for
    # special tokens ([CLS]/[SEP]-style), so scan backward for the last token
    # that actually maps to characters instead of blindly using the final
    # token (which previously raised AttributeError on a trailing special token).
    if not span_of_interest.stop:
        if span_type == "word":
            default_stop = len(encoding.ids)
        else:
            default_stop = 0
            for j in range(len(encoding.ids) - 1, -1, -1):
                last_span = encoding.token_to_chars(j)
                if last_span is not None:
                    default_stop = last_span[1]
                    break
        span_of_interest = slice(
            span_of_interest.start or 0, default_stop, span_of_interest.step
        )
    else:
        span_of_interest = slice(
            span_of_interest.start or 0,
            span_of_interest.stop,
            span_of_interest.step,
        )

    start_token = 0
    end_token = len(encoding.ids)
    for i, _ in enumerate(encoding.ids):
        span = encoding.token_to_chars(i)
        word = encoding.token_to_word(i)

        if span is None or word is None:
            # for [CLS] etc., no span is returned; such tokens carry no
            # character or word alignment, so they cannot bound the match
            continue
        # token_to_chars yields a (start, end) pair (plain tuple or CharSpan);
        # unpack instead of converting through transformers' CharSpan, since
        # only the two fields are ever read
        char_start, char_end = span

        if span_type == "char":
            # advance the start while tokens still begin at/before the target
            # start; the first token whose end covers `stop` closes the span
            if char_start <= span_of_interest.start:
                start_token = i
            if char_end >= span_of_interest.stop:
                end_token = i + 1
                break
        elif span_type == "word":
            if word < span_of_interest.start:
                start_token = i + 1
            # watch out for the semantics of the "stop": it is exclusive, so
            # the first token of the stop word is the (exclusive) token stop
            if word == span_of_interest.stop:
                end_token = i
            elif word > span_of_interest.stop:
                break

    assert end_token - start_token <= len(
        encoding.ids
    ), "Extracted span is larger than original span"

    return slice(start_token, end_token)
Picks token indices in a tokenized encoded sequence that best correspond to
a substring of interest in the original sequence, given by a char span (slice).

Args:
    encoding (transformers.tokenization_utils_base.BatchEncoding): the output of a
        `tokenizer(text)` call on a single text instance (not a batch, i.e.
        `tokenizer([text])`).
    span_of_interest (slice): a `slice` object denoting the character indices in the
        original `text` string we want to extract the corresponding tokens for.
    span_type (str): either `char` or `word`, denoting what type of span we are
        interested in obtaining. This argument has no default to ensure the user is
        aware of what kind of span they are getting from this function.

Returns:
    slice: the start and stop indices of **tokens** within an encoded sequence that
        best match the `span_of_interest`.