diff --git a/pyproject.toml b/pyproject.toml index acc0b3a8..208be6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,10 @@ dependencies = [ "opentelemetry-api>=1.28.2", ] +[project.scripts] +lint = "lint:main" +format = "lint:main" + [tool.uv] dev-dependencies = [ "deptry>=0.14.0", diff --git a/src/lint.py b/src/lint.py new file mode 100644 index 00000000..6ac5fa1c --- /dev/null +++ b/src/lint.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +import os +import sys + + +def raise_err(code: int) -> None: + if code > 0: + sys.exit(1) + + +def main() -> None: + fix = ["--fix"] if "--fix" in sys.argv else [] + raise_err(os.system(" ".join(["ruff", "check", "src"] + fix))) + raise_err(os.system("ruff format src")) + raise_err(os.system("mypy src")) + raise_err(os.system("pyright src")) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index f7271dc7..3d10c0be 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -35,8 +35,8 @@ TypeExpression, TypeName, UnionTypeExpr, - ensure_literal_type, extract_inner_type, + render_literal_type, render_type_expr, ) @@ -167,7 +167,7 @@ def encode_type( in_module: list[ModuleName], permit_unknown_members: bool, ) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: - encoder_name: Optional[str] = None # defining this up here to placate mypy + encoder_name: TypeName | None = None # defining this up here to placate mypy chunks: List[FileContents] = [] if isinstance(type, RiverNotType): return (TypeName("None"), [], [], set()) @@ -234,7 +234,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: and prop.const is not None ].pop() one_of_pending.setdefault( - f"{prefix}OneOf_{discriminator_value}", + f"{render_literal_type(prefix)}OneOf_{discriminator_value}", (discriminator_value, []), )[1].append(oneof_t) @@ -270,12 +270,13 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: oneof_t.properties.keys() ).difference(common_members) encoder_name = TypeName( - f"encode_{ensure_literal_type(type_name)}" + f"encode_{render_literal_type(type_name)}" ) encoder_names.add(encoder_name) + _field_name = render_literal_type(encoder_name) typeddict_encoder.append( f"""\ - {encoder_name}(x) # type: ignore[arg-type] + {_field_name}(x) # type: ignore[arg-type] """.strip() ) if local_discriminators: @@ -299,12 +300,14 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: one_of.append(type_name) chunks.extend(contents) encoder_name = TypeName( - f"encode_{ensure_literal_type(type_name)}" + f"encode_{render_literal_type(type_name)}" ) # TODO(dstewart): Figure out why uncommenting this breaks # generated code # encoder_names.add(encoder_name) - typeddict_encoder.append(f"{encoder_name}(x)") + typeddict_encoder.append( + f"{render_literal_type(encoder_name)}(x)" + ) typeddict_encoder.append( f""" if x[{repr(discriminator_name)}] @@ -317,19 +320,27 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: union = OpenUnionTypeExpr(UnionTypeExpr(one_of)) else: union = UnionTypeExpr(one_of) - chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}")) + chunks.append( + FileContents( + f"{render_literal_type(prefix)} = {render_type_expr(union)}" + ) + ) chunks.append(FileContents("")) if base_model == "TypedDict": - encoder_name = TypeName(f"encode_{prefix}") + encoder_name = TypeName(f"encode_{render_literal_type(prefix)}") encoder_names.add(encoder_name) + _field_name = render_literal_type(encoder_name) + _field_type = ( + f"Callable[[{repr(render_literal_type(prefix))}], Any]" + ) chunks.append( FileContents( "\n".join( [ dedent( f"""\ - {encoder_name}: Callable[[{repr(prefix)}], Any] = ( + {_field_name}: {_field_type} = ( lambda x: """.rstrip() ) @@ -349,7 +360,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: for i, t in enumerate(type.anyOf): type_name, _, contents, _ = encode_type( t, - TypeName(f"{prefix}AnyOf_{i}"), + TypeName(f"{render_literal_type(prefix)}AnyOf_{i}"), base_model, in_module, permit_unknown_members=permit_unknown_members, @@ -366,7 +377,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: match type_name: case ListTypeExpr(inner_type_name): typeddict_encoder.append( - f"encode_{ensure_literal_type(inner_type_name)}(x)" + f"encode_{render_literal_type(inner_type_name)}(x)" ) case DictTypeExpr(_): raise ValueError( @@ -377,7 +388,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: typeddict_encoder.append(repr(const)) case other: typeddict_encoder.append( - f"encode_{ensure_literal_type(other)}(x)" + f"encode_{render_literal_type(other)}(x)" ) if permit_unknown_members: union = OpenUnionTypeExpr(UnionTypeExpr(any_of)) @@ -385,17 +396,18 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: union = UnionTypeExpr(any_of) if is_literal(type): typeddict_encoder = ["x"] - chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}")) + chunks.append( + FileContents(f"{render_literal_type(prefix)} = {render_type_expr(union)}") + ) if base_model == "TypedDict": - encoder_name = TypeName(f"encode_{prefix}") + encoder_name = TypeName(f"encode_{render_literal_type(prefix)}") encoder_names.add(encoder_name) + _field_name = render_literal_type(encoder_name) + _field_type = f"Callable[[{repr(render_literal_type(prefix))}], Any]" chunks.append( FileContents( "\n".join( - [ - f"{encoder_name}: Callable[[{repr(prefix)}], Any] = (" - "lambda x: " - ] + [f"{_field_name}: {_field_type} = (lambda x: "] + typeddict_encoder + [")"] ) @@ -491,7 +503,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: match type_name: case ListTypeExpr(inner_type_name): typeddict_encoder.append( - f"encode_{ensure_literal_type(inner_type_name)}(x)" + f"encode_{render_literal_type(inner_type_name)}(x)" ) case DictTypeExpr(_): raise ValueError( @@ -500,11 +512,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: case LiteralTypeExpr(const): typeddict_encoder.append(repr(const)) case other: - typeddict_encoder.append(f"encode_{ensure_literal_type(other)}(x)") + typeddict_encoder.append(f"encode_{render_literal_type(other)}(x)") return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names) assert type.type == "object", type.type - current_chunks: List[str] = [f"class {prefix}({base_model}):"] + current_chunks: List[str] = [ + f"class {render_literal_type(prefix)}({base_model}):" + ] # For the encoder path, do we need "x" to be bound? # lambda x: ... vs lambda _: {} needs_binding = False @@ -519,7 +533,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: typeddict_encoder.append(f"{repr(name)}:") type_name, _, contents, _ = encode_type( prop, - TypeName(prefix + name.title()), + TypeName(prefix.value + name.title()), base_model, in_module, permit_unknown_members=permit_unknown_members, @@ -531,17 +545,19 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: typeddict_encoder.append("'not implemented'") elif isinstance(prop, RiverUnionType): encoder_name = TypeName( - f"encode_{ensure_literal_type(type_name)}" + f"encode_{render_literal_type(type_name)}" ) encoder_names.add(encoder_name) - typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])") + typeddict_encoder.append( + f"{render_literal_type(encoder_name)}(x[{repr(name)}])" + ) if name not in type.required: typeddict_encoder.append( f"if {repr(name)} in x and x[{repr(name)}] else None" ) elif isinstance(prop, RiverIntersectionType): encoder_name = TypeName( - f"encode_{ensure_literal_type(type_name)}" + f"encode_{render_literal_type(type_name)}" ) encoder_names.add(encoder_name) typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])") @@ -552,11 +568,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: safe_name = name if prop.type == "object" and not prop.patternProperties: encoder_name = TypeName( - f"encode_{ensure_literal_type(type_name)}" + f"encode_{render_literal_type(type_name)}" ) encoder_names.add(encoder_name) typeddict_encoder.append( - f"{encoder_name}(x[{repr(safe_name)}])" + f"{render_literal_type(encoder_name)}(x[{repr(safe_name)}])" ) if name not in prop.required: typeddict_encoder.append( @@ -582,14 +598,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: match type_name: case ListTypeExpr(inner_type_name): encoder_name = TypeName( - f"encode_{ensure_literal_type(inner_type_name)}" + f"encode_{render_literal_type(inner_type_name)}" ) encoder_names.add(encoder_name) typeddict_encoder.append( dedent( f"""\ [ - {encoder_name}(y) + {render_literal_type(encoder_name)}(y) for y in x[{repr(name)}] ] """.rstrip() @@ -679,8 +695,10 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: if base_model == "TypedDict": binding = "x" if needs_binding else "_" - encoder_name = TypeName(f"encode_{prefix}") + encoder_name = TypeName(f"encode_{render_literal_type(prefix)}") encoder_names.add(encoder_name) + _field_name = render_literal_type(encoder_name) + _field_type = f"Callable[[{repr(render_literal_type(prefix))}], Any]" current_chunks.insert( 0, FileContents( @@ -688,7 +706,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: [ dedent( f"""\ - {encoder_name}: Callable[[{repr(prefix)}], Any] = ( + {_field_name}: {_field_type} = ( lambda {binding}: """ ) @@ -847,7 +865,7 @@ def __init__(self, client: river.Client[Any]): f"lambda xs: [encode_{init_type_name}(x) for x in xs]" ) else: - render_init_method = f"encode_{ensure_literal_type(init_type)}" + render_init_method = f"encode_{render_literal_type(init_type)}" else: render_init_method = f"""\ lambda x: TypeAdapter({render_type_expr(init_type)}) @@ -870,11 +888,11 @@ def __init__(self, client: river.Client[Any]): case ListTypeExpr(input_type_name): render_input_method = f"""\ lambda xs: [ - encode_{ensure_literal_type(input_type_name)}(x) for x in xs + encode_{render_literal_type(input_type_name)}(x) for x in xs ] """ else: - render_input_method = f"encode_{ensure_literal_type(input_type)}" + render_input_method = f"encode_{render_literal_type(input_type)}" else: render_input_method = f"""\ lambda x: TypeAdapter({render_type_expr(input_type)}) @@ -957,9 +975,9 @@ async def {name}( f"""\ async def {name}( self, - init: {init_type}, + init: {render_type_expr(init_type)}, inputStream: AsyncIterable[{render_type_expr(input_type)}], - ) -> {output_type}: + ) -> {render_type_expr(output_type)}: return await self.client.send_upload( {repr(schema_name)}, {repr(name)}, @@ -1069,8 +1087,11 @@ async def {name}( existing = emitted_files.get(file_path, FileContents(FILE_HEADER)) emitted_files[file_path] = FileContents("\n".join([existing] + contents)) + def render_names(xs: set[TypeName]) -> str: + return ", ".join(sorted(render_literal_type(x) for x in xs)) + rendered_imports = [ - f"from .{dotted_modules} import {', '.join(sorted(names))}" + f"from .{dotted_modules} import {render_names(names)}" for dotted_modules, names in imports.items() ] diff --git a/src/replit_river/codegen/typing.py b/src/replit_river/codegen/typing.py index 87c577bf..5808a2d2 100644 --- a/src/replit_river/codegen/typing.py +++ b/src/replit_river/codegen/typing.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import NewType, assert_never -TypeName = NewType("TypeName", str) ModuleName = NewType("ModuleName", str) ClassName = NewType("ClassName", str) FileContents = NewType("FileContents", str) @@ -10,30 +9,53 @@ RenderedPath = NewType("RenderedPath", str) -@dataclass +@dataclass(frozen=True) +class TypeName: + value: str + + def __str__(self) -> str: + raise Exception("Complex type must be put through render_type_expr!") + + +@dataclass(frozen=True) class DictTypeExpr: nested: "TypeExpression" + def __str__(self) -> str: + raise Exception("Complex type must be put through render_type_expr!") -@dataclass + +@dataclass(frozen=True) class ListTypeExpr: nested: "TypeExpression" + def __str__(self) -> str: + raise Exception("Complex type must be put through render_type_expr!") + -@dataclass +@dataclass(frozen=True) class LiteralTypeExpr: nested: int | str + def __str__(self) -> str: + raise Exception("Complex type must be put through render_type_expr!") + -@dataclass +@dataclass(frozen=True) class UnionTypeExpr: nested: list["TypeExpression"] + def __str__(self) -> str: + raise Exception("Complex type must be put through render_type_expr!") -@dataclass + +@dataclass(frozen=True) class OpenUnionTypeExpr: union: UnionTypeExpr + def __str__(self) -> str: + raise Exception("Complex type must be put through render_type_expr!") + TypeExpression = ( TypeName @@ -62,12 +84,16 @@ def render_type_expr(value: TypeExpression) -> str: "WrapValidator(translate_unknown_value)" "]" ) - case str(name): - return TypeName(name) + case TypeName(name): + return name case other: assert_never(other) +def render_literal_type(value: TypeExpression) -> str: + return render_type_expr(ensure_literal_type(value)) + + def extract_inner_type(value: TypeExpression) -> TypeName: match value: case DictTypeExpr(nested): @@ -84,7 +110,7 @@ def extract_inner_type(value: TypeExpression) -> TypeName: raise ValueError( f"Attempting to extract from a union, currently not possible: {value}" ) - case str(name): + case TypeName(name): return TypeName(name) case other: assert_never(other) @@ -112,7 +138,7 @@ def ensure_literal_type(value: TypeExpression) -> TypeName: raise ValueError( f"Unexpected expression when expecting a type name: {value}" ) - case str(name): + case TypeName(name): return TypeName(name) case other: assert_never(other)