surprisal.utils

utility functions supporting model and surprisal classes

  1"""
  2utility functions supporting model and surprisal classes
  3"""
  4
  5import tokenizers
  6from transformers import tokenization_utils_base
  7
  8
def hf_pick_matching_token_ixs(
    encoding: "tokenizers.Encoding", span_of_interest: slice, span_type: str
) -> slice:
    """Picks token indices in a tokenized encoded sequence that best correspond to
        a substring of interest in the original sequence, given by a span (slice)
    Args:
        encoding (tokenizers.Encoding): the encoding of a single text instance,
            e.g. `tokenizer(text)[0]` for a fast tokenizer (not a batch, i.e.
            `tokenizer([text])`)
        span_of_interest (slice): a `slice` object denoting the character (for
            `span_type="char"`) or word (for `span_type="word"`) indices in the
            original `text` string we want to extract the corresponding tokens for
        span_type (str): either `char` or `word`, denoting what type of span we
            are interested in obtaining. This argument has no default to ensure
            the user is aware of what kind of span they are getting from this
            function
    Returns:
        slice: the start and stop indices of **tokens** within the encoded
            sequence that best match the `span_of_interest`
    """
    span_of_interest = slice(
        span_of_interest.start or 0,
        span_of_interest.stop
        or (
            len(encoding.ids)
            if span_type == "word"
            # token_to_chars returns a plain (start, end) offset tuple here,
            # so index it rather than relying on a `.end` attribute
            else encoding.token_to_chars(len(encoding.ids) - 1)[1]
        ),
        span_of_interest.step,
    )

    start_token = 0
    end_token = len(encoding.ids)
    for i, _ in enumerate(encoding.ids):
        span = encoding.token_to_chars(i)
        word = encoding.token_to_word(i)

        if span is None or word is None:
            # special tokens such as [CLS] map back to no span in the
            # original string, so they are skipped
            continue
        # convert the raw (start, end) offset tuple to a CharSpan so that the
        # `.start`/`.end` attributes below work
        span = tokenization_utils_base.CharSpan(*span)

        if span_type == "char":
            # start at the last token that begins at or before `start`; end
            # just past the first token whose span reaches `stop`
            if span.start <= span_of_interest.start:
                start_token = i
            if span.end >= span_of_interest.stop:
                end_token = i + 1
                break
        elif span_type == "word":
            if word < span_of_interest.start:
                start_token = i + 1
            # the `stop` of the slice is exclusive, so end the token span at
            # the first token belonging to the word with index `stop`
            if word == span_of_interest.stop:
                end_token = i
                break
            elif word > span_of_interest.stop:
                break

    assert end_token - start_token <= len(
        encoding.ids
    ), "Extracted span is larger than original span"

    return slice(start_token, end_token)


openai_models_list = [
    "davinci-instruct-beta",
    "babbage",
    "text-similarity-ada-001",
    "babbage-code-search-code",
    "code-davinci-edit-001",
    "ada",
    "ada-similarity",
    "babbage-search-query",
    "text-search-curie-query-001",
    "babbage-search-document",
    "davinci-search-document",
    "text-curie-001",
    "text-similarity-babbage-001",
    "text-similarity-curie-001",
    "code-search-ada-text-001",
    "text-search-ada-doc-001",
    "audio-transcribe-001",
    "text-search-curie-doc-001",
    "curie-similarity",
    "ada-search-document",
    "text-davinci-insert-001",
    "text-search-davinci-doc-001",
    "ada-search-query",
    "text-search-ada-query-001",
    "text-davinci-001",
    "curie",
    "curie-instruct-beta",
    "babbage-similarity",
    "ada-code-search-text",
    "davinci-similarity",
    "text-search-davinci-query-001",
    "babbage-code-search-text",
    "code-search-babbage-code-001",
    "text-davinci-002",
    "text-davinci-003",
    "text-ada-001",
    "davinci-search-query",
    "ada-code-search-code",
    "curie-search-document",
    "text-similarity-davinci-001",
    "text-davinci-insert-002",
    "code-search-babbage-text-001",
    "text-davinci-edit-001",
    "text-search-babbage-query-001",
    "davinci",
    "text-search-babbage-doc-001",
    "curie-search-query",
    "text-babbage-001",
    "code-search-ada-code-001",
    "cushman:2020-05-03",
    "ada:2020-05-03",
    "babbage:2020-05-03",
    "curie:2020-05-03",
    "davinci:2020-05-03",
    "if-davinci-v2",
    "if-curie-v2",
    "if-davinci:3.0.0",
    "davinci-if:3.0.0",
    "davinci-instruct-beta:2.0.0",
    "text-ada:001",
    "text-davinci:001",
    "text-curie:001",
    "text-babbage:001",
]
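
A minimal usage sketch of `hf_pick_matching_token_ixs` (not part of the module),
assuming a HuggingFace fast tokenizer for `gpt2` and
`from surprisal.utils import hf_pick_matching_token_ixs`; the text and span
values are illustrative:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
    text = "The quick brown fox jumps over the lazy dog"
    # indexing a BatchEncoding with an int yields the underlying
    # `tokenizers.Encoding` for that item
    encoding = tokenizer(text)[0]

    # tokens covering the characters of "brown fox" (character indices 10..18)
    char_ixs = hf_pick_matching_token_ixs(encoding, slice(10, 19), "char")
    print(encoding.tokens[char_ixs])  # the tokens spanning "brown fox"

    # tokens covering words 2 and 3 ("brown" and "fox"); the stop index is
    # exclusive, matching usual slice semantics
    word_ixs = hf_pick_matching_token_ixs(encoding, slice(2, 4), "word")
    print(encoding.tokens[word_ixs])

The return value is an ordinary `slice`, so it can be applied directly to
`encoding.tokens` or `encoding.ids`.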
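
`openai_models_list` is a static snapshot of OpenAI model identifiers; the set
of models actually served by the API can differ over time. A hypothetical
validation helper (not part of the module) might consult it like this:

    def is_known_openai_model(model_id: str) -> bool:
        # hypothetical helper: membership check against the static snapshot
        # above; treat it as a best-effort guard, not ground truth
        return model_id in openai_models_list

    assert is_known_openai_model("text-davinci-003")
    assert not is_known_openai_model("not-a-real-model")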