Skip to content

Commit 0224edc

Browse files
Add Switch-Enum monitor
1 parent 62d6c24 commit 0224edc

File tree

5 files changed

+390
-0
lines changed

5 files changed

+390
-0
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu
174174
#### Dereferences Monitor
175175
[src/monitors4codegen/monitor_guided_decoding/monitors/dereferences_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/dereferences_monitor.py) provides the instantiation of `Monitor` class for dereferences monitor. It can be used to guide LMs to generate valid identifier dereferences. Unit tests for the dereferences monitor are present in [tests/monitor_guided_decoding/test_dereferences_monitor_java.py](tests/monitor_guided_decoding/test_dereferences_monitor_java.py), which also provide usage examples for the dereferences monitor.
176176

177+
#### Switch-Enum Monitor
178+
[src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor.
179+
177180
## Contributing
178181

179182
This project welcomes contributions and suggestions. Most contributions require you to agree to a

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies = [
2828
"jedi-language-server==0.41.1",
2929
"pydantic==1.10.5",
3030
"code-tokenize==0.2.0",
31+
"code-ast @ git+https://github.com/cedricrupb/code_ast@982940d04b1d721e5ac9a97d433f36d1fb47e8e0",
3132
"openai==1.3.3",
3233
"torch==1.12.0",
3334
"transformers==4.30.0",

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ pytest-asyncio==0.21.1
1313
pygtrie==2.5.0
1414
openai==1.3.3
1515
code-tokenize==0.2.0
16+
code-ast @ git+https://github.com/cedricrupb/code_ast@982940d04b1d721e5ac9a97d433f36d1fb47e8e0
1617
--extra-index-url https://download.pytorch.org/whl/cu113
1718
torch==1.12.0+cu113
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""
2+
This module provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement
3+
"""
4+
5+
from typing import List
6+
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor, DecoderStates
7+
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
8+
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import TokenizerWrapper
9+
from monitors4codegen.multilspy.multilspy_utils import TextUtils
10+
from monitors4codegen.multilspy import multilspy_types
11+
12+
class SwitchEnumMonitor(DereferencesMonitor):
13+
"""
14+
Provides the switch-enum monitor, that is invoked when "case " is typed in a switch statement to provide
15+
enum values as completions
16+
"""
17+
def __init__(self, tokenizer: TokenizerWrapper, monitor_file_buffer: MonitorFileBuffer, responsible_for_file_buffer_state: bool = True) -> None:
18+
super().__init__(tokenizer, monitor_file_buffer, responsible_for_file_buffer_state)
19+
self.all_break_chars.remove('.')
20+
21+
async def pre(self) -> None:
22+
cursor_idx = TextUtils.get_index_from_line_col(
23+
self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path),
24+
self.monitor_file_buffer.current_lc[0],
25+
self.monitor_file_buffer.current_lc[1],
26+
)
27+
text_upto_cursor = self.monitor_file_buffer.lsp.get_open_file_text(self.monitor_file_buffer.file_path)[
28+
:cursor_idx
29+
]
30+
31+
# TODO: pre can be improved by checking for r"switch.*case", and obtaining completions, and then prefixing a whitespace
32+
if not text_upto_cursor.endswith("case "):
33+
self.decoder_state = DecoderStates.S0
34+
return
35+
36+
completions = await self.a_phi()
37+
if len(completions) == 0:
38+
self.decoder_state = DecoderStates.S0
39+
else:
40+
self.decoder_state = DecoderStates.Constrained
41+
self.legal_completions = completions
42+
43+
async def a_phi(self) -> List[str]:
44+
relative_file_path = self.monitor_file_buffer.file_path
45+
line, column = self.monitor_file_buffer.current_lc
46+
47+
with self.monitor_file_buffer.lsp.open_file(relative_file_path):
48+
legal_completions = await self.monitor_file_buffer.lsp.request_completions(
49+
relative_file_path, line, column
50+
)
51+
legal_completions = [
52+
completion["completionText"]
53+
for completion in legal_completions
54+
if completion["kind"] == multilspy_types.CompletionItemKind.EnumMember
55+
]
56+
57+
return legal_completions
58+
59+
async def update(self, generated_token: str):
60+
"""
61+
Updates the monitor state based on the generated token
62+
"""
63+
if self.responsible_for_file_buffer_state:
64+
self.monitor_file_buffer.append_text(generated_token)
65+
if self.decoder_state == DecoderStates.Constrained:
66+
for break_char in self.all_break_chars:
67+
if break_char in generated_token:
68+
self.decoder_state = DecoderStates.S0
69+
self.legal_completions = None
70+
return
71+
72+
# No breaking characters found. Continue in constrained state
73+
self.legal_completions = [
74+
legal_completion[len(generated_token) :]
75+
for legal_completion in self.legal_completions
76+
if legal_completion.startswith(generated_token)
77+
]
78+
else:
79+
# Nothing to be done in other states
80+
return

0 commit comments

Comments
 (0)