
Commit c98f700

Add OpenAI MGD support
1 parent eaa4bd5

File tree: 10 files changed (+452 −75 lines)

README.md

Lines changed: 11 additions & 1 deletion
```diff
@@ -162,7 +162,17 @@ pytest tests/multilspy
 
 ## 5. Monitor-Guided Decoding
 
-A monitor under the Monitor-Guided Decoding framework, is instantiated using `multilspy` as the LSP client, and as a logits-processor to guide the LM decoding. [src/monitors4codegen/monitor_guided_decoding/monitor.py](src/monitors4codegen/monitor_guided_decoding/monitor.py) provides the class `MGDLogitsProcessor` which can be used with any HuggingFace Language Model, as a `LogitsProcessor` to guide the LM using MGD. [src/monitors4codegen/monitor_guided_decoding/dereferences_monitor.py](src/monitors4codegen/monitor_guided_decoding/dereferences_monitor.py) provides the instantiation for dereferences monitor. Unit tests for the dereferences monitor are present in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py), which also provide usage examples for the dereferences monitor.
+A monitor under the Monitor-Guided Decoding framework is instantiated using `multilspy` as the LSP client and provides `maskgen` to guide LM decoding. The monitor interface is defined by the class `Monitor` in [src/monitors4codegen/monitor_guided_decoding/monitor.py](src/monitors4codegen/monitor_guided_decoding/monitor.py). The interface is implemented by various monitors supporting different properties, such as valid identifier dereferences, valid numbers of arguments, and valid typestate method calls.
+
+### MGD with HuggingFace models
+[src/monitors4codegen/monitor_guided_decoding/hf_gen.py](src/monitors4codegen/monitor_guided_decoding/hf_gen.py) provides the class `MGDLogitsProcessor`, which can be used with any HuggingFace language model as a [`LogitsProcessor`](https://huggingface.co/docs/transformers/internal/generation_utils#logitsprocessor) to guide the LM using MGD. Example uses with the [SantaCoder](https://huggingface.co/bigcode/santacoder) model are available in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py).
+
+### MGD with OpenAI models
+[src/monitors4codegen/monitor_guided_decoding/openai_gen.py](src/monitors4codegen/monitor_guided_decoding/openai_gen.py) provides the method `openai_mgd`, which takes a prompt and a `Monitor` as input and returns the MGD-guided generation from an OpenAI model.
+
+### Monitors
+#### Dereferences Monitor
+[src/monitors4codegen/monitor_guided_decoding/dereferences_monitor.py](src/monitors4codegen/monitor_guided_decoding/dereferences_monitor.py) provides the instantiation of the `Monitor` class for the dereferences monitor. It can be used to guide LMs to generate valid identifier dereferences. Unit tests for the dereferences monitor are present in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py), which also provides usage examples for the dereferences monitor.
 
 ## Contributing
 
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@ dependencies = [
     "jedi-language-server==0.41.1",
     "pydantic==1.10.5",
     "code-tokenize==0.2.0",
+    "openai==1.3.3",
     "torch==1.12.0",
     "transformers==4.30.0",
     "tiktoken==0.3.3",
```

requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,6 +11,7 @@ pytest==7.3.1
 pydantic==1.10.5
 pytest-asyncio==0.21.1
 pygtrie==2.5.0
+openai==1.3.3
 code-tokenize==0.2.0
 --extra-index-url https://download.pytorch.org/whl/cu113
 torch==1.12.0+cu113
```
src/monitors4codegen/monitor_guided_decoding/hf_gen.py (new file)

Lines changed: 64 additions & 0 deletions

```python
"""
Provides the definition of a monitor as per the Monitor-Guided Decoding framework
"""

import asyncio
import torch

from asyncio.events import AbstractEventLoop
from typing import List, Union
from transformers import LogitsProcessor
from monitors4codegen.monitor_guided_decoding.monitor import Monitor

class MGDLogitsProcessor(LogitsProcessor):
    """
    Provides the logits processor for monitor guided decoding
    """

    loop: AbstractEventLoop

    def __init__(self, monitors: List[Monitor], loop: Union[None, AbstractEventLoop] = None) -> None:
        super().__init__()

        if loop is None:
            self.loop = asyncio.get_event_loop()
        else:
            self.loop = loop

        self.monitors: List[Monitor] = monitors

    async def process_scores_for_single_input_id(
        self, segment_idx: int, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Asynchronously processes the scores for a single input id using the MGD framework
        """
        blacklisted_ids: List[int] = await self.monitors[segment_idx].maskgen(input_ids.tolist())
        output_scores: torch.FloatTensor = torch.where(
            torch.tensor([True if i in blacklisted_ids else False for i in range(scores.shape[0])]).to(scores.device),
            float("-inf") * torch.ones(scores.shape[0]).to(scores.device),
            scores,
        ).to(scores)
        return output_scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        This method is called by the HuggingFace decoder, for every token generation with
        the input_ids (seen so far including prompt) and scores (for the next token).
        This method processes the scores using the MGD framework.
        """
        assert len(input_ids.shape) == 2
        assert input_ids.shape[0] == len(self.monitors)
        assert len(scores.shape) == 2

        async def f(input_ids_arg: torch.LongTensor, scores_arg: torch.FloatTensor):
            new_score_coroutines = [
                self.process_scores_for_single_input_id(i, input_ids_arg[i], scores_arg[i])
                for i in range(input_ids_arg.shape[0])
            ]
            new_scores = await asyncio.gather(*new_score_coroutines)
            return tuple(new_scores)

        future = asyncio.run_coroutine_threadsafe(f(input_ids, scores), self.loop)
        results = future.result()
        new_scores = torch.stack(results, dim=0).to(scores)
        return new_scores
```
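For orientation, a minimal usage sketch of `MGDLogitsProcessor` with HuggingFace `generate` (not part of the commit; the tests in tests/monitor_guided_decoding/test_dereferences_monitor_java.py are the authoritative examples). The names `monitor`, `lsp_loop`, and `prompt` are assumptions: an already-constructed `Monitor`, the event loop the multilspy language server runs on in a background thread, and the code prefix to complete.

```python
# Hypothetical usage sketch, not part of the commit. Assumes:
#   - `monitor` is an already-constructed Monitor (e.g. a dereferences monitor),
#   - `lsp_loop` is the asyncio loop the multilspy server runs on (in a background
#     thread), so run_coroutine_threadsafe inside MGDLogitsProcessor can make
#     progress while generate() blocks,
#   - `prompt` is the code prefix to complete.
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor

model = AutoModelForCausalLM.from_pretrained("bigcode/santacoder", trust_remote_code=True)
hf_tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

mgd_processor = MGDLogitsProcessor([monitor], loop=lsp_loop)  # one monitor per batch element

input_ids = hf_tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=30,
    logits_processor=LogitsProcessorList([mgd_processor]),  # MGD masks blacklisted tokens
)
print(hf_tokenizer.decode(output[0][input_ids.shape[1]:]))
```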

src/monitors4codegen/monitor_guided_decoding/monitor.py

Lines changed: 1 addition & 62 deletions
```diff
@@ -2,19 +2,13 @@
 Provides the definition of a monitor as per the Monitor-Guided Decoding framework
 """
 
-import asyncio
-import torch
-
-from asyncio.events import AbstractEventLoop
-from typing import List, Tuple, Union
-from transformers import LogitsProcessor
+from typing import List, Tuple
 from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
 from monitors4codegen.multilspy import LanguageServer
 from monitors4codegen.multilspy.multilspy_config import Language
 from dataclasses import dataclass
 from monitors4codegen.multilspy.multilspy_utils import TextUtils
 
-
 @dataclass
 class MonitorFileBuffer:
     """
@@ -83,58 +77,3 @@ def update(self, generated_token: str):
         This function updates the state of the monitor, given the generated token.
         """
         raise NotImplementedError()
-
-
-class MGDLogitsProcessor(LogitsProcessor):
-    """
-    Provides the logits processor for monitor guided decoding
-    """
-
-    loop: AbstractEventLoop
-
-    def __init__(self, monitors: List[Monitor], loop: Union[None, AbstractEventLoop] = None) -> None:
-        super().__init__()
-
-        if loop is None:
-            self.loop = asyncio.get_event_loop()
-        else:
-            self.loop = loop
-
-        self.monitors: List[Monitor] = monitors
-
-    async def process_scores_for_single_input_id(
-        self, segment_idx: int, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
-        """
-        Asynchronously processes the scores for a single input id using the MGD framework
-        """
-        blacklisted_ids: List[int] = await self.monitors[segment_idx].maskgen(input_ids.tolist())
-        output_scores: torch.FloatTensor = torch.where(
-            torch.tensor([True if i in blacklisted_ids else False for i in range(scores.shape[0])]).to(scores.device),
-            float("-inf") * torch.ones(scores.shape[0]).to(scores.device),
-            scores,
-        ).to(scores)
-        return output_scores
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        """
-        This method is called by the HuggingFace decoder, for every token generation with
-        the input_ids (seen so far including prompt) and scores (for the next token).
-        This method processes the scores using the MGD framework.
-        """
-        assert len(input_ids.shape) == 2
-        assert input_ids.shape[0] == len(self.monitors)
-        assert len(scores.shape) == 2
-
-        async def f(input_ids_arg: torch.LongTensor, scores_arg: torch.FloatTensor):
-            new_score_coroutines = [
-                self.process_scores_for_single_input_id(i, input_ids_arg[i], scores_arg[i])
-                for i in range(input_ids_arg.shape[0])
-            ]
-            new_scores = await asyncio.gather(*new_score_coroutines)
-            return tuple(new_scores)
-
-        future = asyncio.run_coroutine_threadsafe(f(input_ids, scores), self.loop)
-        results = future.result()
-        new_scores = torch.stack(results, dim=0).to(scores)
-        return new_scores
```
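The `Monitor` interface that remains in this file leaves `maskgen` and `update` abstract. Below is a hypothetical, do-nothing sketch of a subclass, with signatures inferred from the `update` stub above and from how `maskgen` is awaited in hf_gen.py; dereferences_monitor.py is the real reference implementation.

```python
# Hypothetical shape of a Monitor implementation (signatures inferred from the
# update() stub above and from how hf_gen.py awaits maskgen); see
# dereferences_monitor.py for the real instantiation.
from typing import List

from monitors4codegen.monitor_guided_decoding.monitor import Monitor

class NoOpMonitor(Monitor):
    async def maskgen(self, input_ids: List[int]) -> List[int]:
        # Token ids to blacklist at the next decoding step; an empty list
        # leaves the LM distribution untouched.
        return []

    def update(self, generated_token: str):
        # Advance monitor state with the token the decoder actually emitted.
        pass
```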
src/monitors4codegen/monitor_guided_decoding/openai_gen.py (new file)

Lines changed: 167 additions & 0 deletions

```python
"""
This module provides the functions and classes for running Monitor-Guided Decoding over OpenAI models
"""

from enum import Enum
import time
from typing import List, Set
import torch
import asyncio

from openai import OpenAI
from monitors4codegen.monitor_guided_decoding.monitor import Monitor
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TikTokenWrapper

class OpenAI_Models(Enum):
    TD3 = 'text-davinci-003'

def openai_mgd(
    client: OpenAI,
    model: OpenAI_Models,
    tokenizer: TikTokenWrapper,
    prompt_tokenized: torch.Tensor,
    temp: float,
    top_p: float,
    monitor: Monitor,
    num_new_tokens: int
):
    """
    This function generates completions with OpenAI models using the Monitor-Guided Decoding scheme.
    """
    prompt_tokenized: torch.Tensor = torch.tensor(prompt_tokenized, dtype=torch.int64)
    assert len(prompt_tokenized.shape) == 1

    all_tokens: torch.Tensor = prompt_tokenized
    gen_text: bytes = b''

    gen_tokens: List[int] = []

    tokens_sort_key = {k: [0, 0] for k in tokenizer.all_token_ids}

    # # TODO: Find a way to prioritize tokens to be blacklisted
    # # 1. The following code uses info about whether the token has a break char in it
    # for token, token_id in tokenizer.vocab_trie.iteritems():
    #     if token[0] in monitor.all_break_chars:
    #         tokens_sort_key[token_id][0] = 0  # ".", ", a"
    #     elif any([c in monitor.all_break_chars for c in token]):
    #         tokens_sort_key[token_id][0] = 1  # "abc, "
    #     else:
    #         tokens_sort_key[token_id][0] = 2

    # # 2. The following code uses frequency of the token in repo as a heuristic
    # for freq_token, freq in metadata_batch[seq_idx]['token_freq']:
    #     tokens_sort_key[freq_token][1] = freq

    # # 3. Use a local-small and very fast language model to score the tokens

    # # 4. Use the prompt to score the tokens

    all_text_bytes: bytes = tokenizer.tokenizer.decode_bytes(all_tokens.tolist())
    prompt_num_tokens: int = all_tokens.shape[0]

    priority_blacklist: List[int] = []

    while all_tokens.shape[0] < prompt_num_tokens + num_new_tokens:
        num_toks_to_gen = (prompt_num_tokens + num_new_tokens) - all_tokens.shape[0]

        blacklisted_ids: List[int] = asyncio.run_coroutine_threadsafe(
            monitor.maskgen(all_tokens.tolist()), monitor.monitor_file_buffer.lsp.server.loop
        ).result()
        white_listed_ids: Set[int] = set(tokenizer.all_token_ids) - set(blacklisted_ids + [50256])

        logit_bias = {50256: -100}

        for token_id in priority_blacklist:
            logit_bias[token_id] = -100

        if len(white_listed_ids) <= (300 - len(logit_bias)):
            for white_token_id in white_listed_ids:
                logit_bias[white_token_id] = 100
        else:
            for candidate_token in sorted(blacklisted_ids, key=lambda x: tokens_sort_key[x], reverse=True):
                if len(logit_bias) >= 300:
                    break
                if candidate_token in blacklisted_ids:
                    logit_bias[candidate_token] = -100

        exponential_backoff_wait = 1
        while True:
            try:
                prompt_arg: str = all_text_bytes.decode('utf-8', errors='strict')
            except UnicodeDecodeError:
                prompt_arg: List[int] = all_tokens.tolist()

            try:
                response = client.completions.create(
                    model=model.value,
                    prompt=[prompt_arg],
                    temperature=temp,
                    max_tokens=num_toks_to_gen if len(logit_bias) <= 1 else 1,
                    top_p=top_p,
                    stop=['.'],
                    logit_bias=logit_bias,
                    logprobs=5
                )
                break
            except Exception:
                time.sleep(exponential_backoff_wait)
                if exponential_backoff_wait < 64:
                    exponential_backoff_wait = exponential_backoff_wait * 2
                else:
                    exponential_backoff_wait = 1

        assert len(response.choices) == 1

        def convert_bytesrep_to_bytes(x: str) -> bytes:
            if x.startswith('bytes:'):
                return bytes.fromhex(x.replace('bytes:', '').replace('\\x', ''))
            else:
                return x.encode()

        tokens_gen_bytes_ = list(map(convert_bytesrep_to_bytes, response.choices[0].logprobs.tokens))
        tokens_gen_bytes = []
        dot_found = False
        for token_bytes in tokens_gen_bytes_:
            gen_text += token_bytes
            all_text_bytes += token_bytes
            tokens_gen_bytes.append(token_bytes)
            if b'.' in token_bytes:
                dot_found = True
                break

        should_manually_add_dot = None
        if response.choices[0].finish_reason == 'stop':
            if dot_found:
                should_manually_add_dot = False
            else:
                should_manually_add_dot = True
        elif response.choices[0].finish_reason == 'length':
            should_manually_add_dot = False
        else:
            raise Exception("Unknown finish reason", response.choices[0].finish_reason)

        tokens_gen = list(map(lambda x: tokenizer.tokenizer.encode_single_token(x), tokens_gen_bytes))

        assert should_manually_add_dot is not None
        if should_manually_add_dot:
            gen_text += b'.'
            all_text_bytes += b'.'
            tokens_gen.append(tokenizer.tokenizer.encode_single_token('.'))

        if len(logit_bias) > 1:
            assert len(tokens_gen) == 1, (print(response), response, launch_debug(locals()))
            if tokens_gen[0] in blacklisted_ids:
                priority_blacklist.append(tokens_gen[0])
                continue
            priority_blacklist = []

        new_all_tokens = torch.cat([
            all_tokens,
            torch.tensor(tokens_gen)
        ]).to(all_tokens)

        assert len(new_all_tokens.shape) == 1
        assert new_all_tokens.shape[0] > all_tokens.shape[0], (new_all_tokens.shape, all_tokens.shape, launch_debug(locals()))
        assert torch.equal(new_all_tokens[:all_tokens.shape[0]], all_tokens)
        gen_tokens += new_all_tokens[all_tokens.shape[0]:].tolist()
        all_tokens = new_all_tokens

    return gen_tokens, gen_text.decode()
```
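Since `openai_mgd` cannot mask logits directly, it approximates MGD through `logit_bias`, which the code caps at 300 entries: it whitelists the allowed tokens when that set is small enough and otherwise blacklists the highest-priority offending tokens, regenerating one token at a time. A hedged usage sketch follows (not from the commit); `monitor` and `prompt` are assumed to be an already-constructed `Monitor` and the code prefix, and `OpenAI()` reads `OPENAI_API_KEY` from the environment.

```python
# Hypothetical usage sketch, not part of the commit. Assumes `monitor` is an
# already-constructed Monitor and `prompt` is the code prefix to complete.
import tiktoken
import torch
from openai import OpenAI

from monitors4codegen.monitor_guided_decoding.openai_gen import OpenAI_Models, openai_mgd
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TikTokenWrapper

client = OpenAI()  # reads OPENAI_API_KEY from the environment
encoding = tiktoken.encoding_for_model("text-davinci-003")
tokenizer = TikTokenWrapper(encoding)  # per tokenizer_wrapper.py, takes a tiktoken Encoding

prompt_tokenized = torch.tensor(encoding.encode(prompt), dtype=torch.int64)
gen_tokens, gen_text = openai_mgd(
    client=client,
    model=OpenAI_Models.TD3,
    tokenizer=tokenizer,
    prompt_tokenized=prompt_tokenized,
    temp=0.0,
    top_p=0.95,
    monitor=monitor,
    num_new_tokens=30,
)
print(gen_text)
```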

src/monitors4codegen/monitor_guided_decoding/tokenizer_wrapper.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -6,7 +6,7 @@
 import torch
 import tiktoken
 
-from typing import List, Union
+from typing import List, Set, Union
 from pygtrie import CharTrie
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
@@ -97,11 +97,11 @@ def __init__(self, tokenizer: tiktoken.core.Encoding):
             self.vocab_trie[decoded_token] = v
             self.all_token_ids.add(v)
 
-    def decode(self, token_ids: torch.Tensor, *args, **kwargs) -> str:
+    def decode(self, token_ids: Union[List[int], torch.Tensor], *args, **kwargs) -> str:
         """
         Decodes the given token ids to a string
         """
-        token_ids, clean_up_tokenization_spaces, skip_special_tokens = None, None, None
+        clean_up_tokenization_spaces, skip_special_tokens = None, None
         if len(args) == 0:
             pass
         elif len(args) == 1:
@@ -116,10 +116,10 @@ def decode(self, token_ids: torch.Tensor, *args, **kwargs) -> str:
 
         assert not clean_up_tokenization_spaces
         assert skip_special_tokens
-        assert isinstance(token_ids, torch.Tensor)
-        token_ids = token_ids.tolist()
+        if isinstance(token_ids, torch.Tensor):
+            token_ids = token_ids.tolist()
 
-        token_ids = [i for i in token_ids if i not in self.all_special_ids]
+        token_ids: List[int] = [i for i in token_ids if i not in self.all_special_ids]
 
         return self.tokenizer.decode(token_ids)
```

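For illustration, an assumption-heavy sketch (not from the commit) of the relaxed `decode`, which now accepts either a plain list or a tensor of token ids; it assumes the elided argument handling in `decode()` reads `skip_special_tokens` from kwargs, which the asserts above require to be truthy.

```python
# Hypothetical sketch of the relaxed TikTokenWrapper.decode; assumes the elided
# kwarg handling in decode() populates skip_special_tokens from kwargs.
import tiktoken
import torch

from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TikTokenWrapper

wrapper = TikTokenWrapper(tiktoken.encoding_for_model("text-davinci-003"))
ids = wrapper.tokenizer.encode("def foo():")

as_list = wrapper.decode(ids, skip_special_tokens=True)                   # list input, now accepted
as_tensor = wrapper.decode(torch.tensor(ids), skip_special_tokens=True)   # tensor input, as before
assert as_list == as_tensor
```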