Skip to content

Commit 6e02870

Browse files
refactor(llm): simplify image processing with single-image transformation
- Replace image_preprocessor with image_transform_fn that works on single images
- Add process_images() generic infrastructure in document_extraction_pipeline
- Simplify extract_receipts_pipeline to only scale_receipt_image function
- Remove nested functions and wrapper layers
- Add image_output_dir parameter (separate from transform function)
- Update README with simpler examples showing direct function references

Benefits:
- No nested functions or imports inside functions
- Clear separation: generic I/O vs specific transformation
- Easy to customize: just pass lambda img: img.rotate(90)
- Simpler mental model: one function transforms one image
1 parent 467f572 commit 6e02870

File tree

3 files changed

+112
-69
lines changed

3 files changed

+112
-69
lines changed

llm/smart_data_extraction_llamaindex/README.md

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ class Invoice(BaseModel):
5656
total_amount: float = Field(description="Total amount")
5757

5858

59-
# 2. Optional: Define transformations
59+
# 2. Optional: Define data transformer
6060
def transform_invoice_data(df: pd.DataFrame) -> pd.DataFrame:
61+
"""Transform extracted invoice data."""
6162
df = df.copy()
6263
df["vendor_name"] = df["vendor_name"].str.upper()
6364
df["total_amount"] = pd.to_numeric(df["total_amount"], errors="coerce")
@@ -82,7 +83,7 @@ if __name__ == "__main__":
8283
output_cls=Invoice,
8384
prompt=INVOICE_PROMPT,
8485
id_column="invoice_id",
85-
transform_fn=transform_invoice_data,
86+
data_transformer=transform_invoice_data,
8687
)
8788

8889
print(result_df)
@@ -99,10 +100,9 @@ def extract_structured_data(
99100
prompt: str,
100101
id_column: str = "document_id",
101102
fields: Optional[List[str]] = None,
102-
preprocess: bool = False,
103-
output_dir: Optional[Path] = None,
104-
scale_factor: int = 3,
105-
transform_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
103+
image_transform_fn: Optional[Callable[[Image.Image], Image.Image]] = None,
104+
image_output_dir: Optional[Path] = None,
105+
data_transformer: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
106106
) -> pd.DataFrame
107107
```
108108

@@ -114,10 +114,9 @@ def extract_structured_data(
114114
**Optional Parameters:**
115115
- `id_column`: Document ID column name (default: "document_id")
116116
- `fields`: Fields to extract (default: all model fields)
117-
- `preprocess`: Enable image preprocessing (default: False)
118-
- `output_dir`: Directory for preprocessed images
119-
- `scale_factor`: Image scaling factor (default: 3)
120-
- `transform_fn`: Custom transformation function
117+
- `image_transform_fn`: Optional function to transform images (takes PIL Image, returns PIL Image)
118+
- `image_output_dir`: Directory to save transformed images (required if image_transform_fn provided)
119+
- `data_transformer`: Optional function to transform the extracted DataFrame
121120

122121
**Returns:**
123122
- `pd.DataFrame`: Extracted data
@@ -142,28 +141,30 @@ result = extract_structured_data(
142141
)
143142
```
144143

145-
### With Image Preprocessing
144+
### With Image Transformation
146145

147146
```python
148147
from pathlib import Path
149-
from extract_receipts_pipeline import Receipt
148+
from PIL import Image
149+
from extract_receipts_pipeline import Receipt, scale_receipt_image
150150

151151
result = extract_structured_data(
152152
image_paths=["low_res.jpg"],
153153
output_cls=Receipt,
154154
prompt="Extract receipt: {context_str}",
155-
preprocess=True,
156-
output_dir=Path("processed_images"),
157-
scale_factor=3,
155+
image_transform_fn=scale_receipt_image, # Simple function reference
156+
image_output_dir=Path("processed_images"),
158157
)
159158
```
160159

161-
### With Custom Transformations
160+
### With Data Transformation
162161

163162
```python
164163
import pandas as pd
165164

166-
def clean_data(df: pd.DataFrame) -> pd.DataFrame:
165+
def transform_form_data(df: pd.DataFrame) -> pd.DataFrame:
166+
"""Clean and normalize extracted form data."""
167+
df = df.copy()
167168
df["name"] = df["name"].str.title()
168169
df["email"] = df["email"].str.lower()
169170
return df
@@ -172,7 +173,7 @@ result = extract_structured_data(
172173
image_paths=["form.pdf"],
173174
output_cls=FormData,
174175
prompt="Extract: {context_str}",
175-
transform_fn=clean_data,
176+
data_transformer=transform_form_data,
176177
)
177178
```
178179

@@ -182,20 +183,36 @@ To create a new document extractor (like the receipt pipeline):
182183

183184
1. Import the generic `extract_structured_data` function from `document_extraction_pipeline`
184185
2. Define your Pydantic schema(s)
185-
3. (Optional) Create transformation function
186-
4. Define extraction prompt
187-
5. Add `__main__` block with example usage
186+
3. (Optional) Create `image_transform_fn` - a simple function that transforms one PIL Image
187+
4. (Optional) Create `data_transformer` function for data transformation
188+
5. Define extraction prompt
189+
6. Add `__main__` block with example usage
190+
191+
**Example image transformation:**
192+
```python
193+
from PIL import Image
194+
195+
def rotate_and_scale(img: Image.Image) -> Image.Image:
196+
"""Custom transformation: rotate 90 degrees and scale up."""
197+
rotated = img.rotate(90, expand=True)
198+
new_size = (rotated.width * 2, rotated.height * 2)
199+
return rotated.resize(new_size, Image.Resampling.LANCZOS)
200+
```
188201

189202
See [extract_receipts_pipeline.py](extract_receipts_pipeline.py) for a complete example.
190203

191204
## Dependencies
192205

193-
Both files include uv inline script dependencies. Required packages:
206+
### Generic Pipeline
207+
Required packages (in `document_extraction_pipeline.py`):
194208
- llama-index
195209
- llama-index-program-openai
196210
- llama-parse
197211
- python-dotenv
198212
- pandas
199-
- pillow
213+
214+
### Receipt Pipeline
215+
Additional packages (in `extract_receipts_pipeline.py`):
216+
- pillow (for image preprocessing)
200217

201218
Run with `uv run <script_name>.py` - dependencies will be automatically installed.

llm/smart_data_extraction_llamaindex/document_extraction_pipeline.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,39 @@ def configure_settings() -> None:
3434
Settings.context_window = 8000
3535

3636

37-
def scale_image(image_path: Path, output_dir: Path, scale_factor: int = 3) -> Path:
38-
"""Scale up an image using high-quality resampling.
37+
def process_images(
38+
image_paths: List[str],
39+
output_dir: Path,
40+
transform_image_fn: Callable[[Image.Image], Image.Image]
41+
) -> List[str]:
42+
"""Process images by applying a transformation function.
43+
44+
Generic infrastructure that loads images, applies transformation, and saves them.
3945
4046
Args:
41-
image_path: Path to the original image
42-
output_dir: Directory to save the scaled image
43-
scale_factor: Factor to scale up the image (default: 3x)
47+
image_paths: List of paths to images
48+
output_dir: Directory to save processed images
49+
transform_image_fn: Function that takes PIL Image and returns transformed PIL Image
4450
4551
Returns:
46-
Path to the scaled image
52+
List of paths to processed images
4753
"""
48-
# Load the image
49-
img = Image.open(image_path)
54+
output_dir.mkdir(parents=True, exist_ok=True)
55+
processed_paths = []
5056

51-
# Scale up the image using high-quality resampling
52-
new_size = (img.width * scale_factor, img.height * scale_factor)
53-
img_resized = img.resize(new_size, Image.Resampling.LANCZOS)
57+
for path in image_paths:
58+
# Load image
59+
img = Image.open(path)
5460

55-
# Save to output directory with same filename
56-
output_dir.mkdir(parents=True, exist_ok=True)
57-
output_path = output_dir / image_path.name
58-
img_resized.save(output_path, quality=95)
61+
# Apply transformation
62+
img_transformed = transform_image_fn(img)
63+
64+
# Save transformed image
65+
output_path = output_dir / Path(path).name
66+
img_transformed.save(output_path, quality=95)
67+
processed_paths.append(str(output_path))
5968

60-
return output_path
69+
return processed_paths
6170

6271

6372
def extract_documents(
@@ -111,15 +120,15 @@ def create_extracted_df(
111120
records: List[dict],
112121
id_column: str,
113122
fields: List[str],
114-
transform_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None
123+
data_transformer: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None
115124
) -> pd.DataFrame:
116125
"""Create DataFrame from extracted records.
117126
118127
Args:
119128
records: List of extraction results with id and data
120129
id_column: Column name for document IDs
121130
fields: List of field names to extract from the Pydantic model
122-
transform_fn: Optional function to transform the DataFrame
131+
data_transformer: Optional function to transform the DataFrame
123132
124133
Returns:
125134
DataFrame with extracted fields
@@ -134,8 +143,8 @@ def create_extracted_df(
134143
]
135144
)
136145

137-
if transform_fn:
138-
df = transform_fn(df)
146+
if data_transformer:
147+
df = data_transformer(df)
139148

140149
return df
141150

@@ -146,10 +155,9 @@ def extract_structured_data(
146155
prompt: str,
147156
id_column: str = "document_id",
148157
fields: Optional[List[str]] = None,
149-
preprocess: bool = False,
150-
output_dir: Optional[Path] = None,
151-
scale_factor: int = 3,
152-
transform_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
158+
image_transform_fn: Optional[Callable[[Image.Image], Image.Image]] = None,
159+
image_output_dir: Optional[Path] = None,
160+
data_transformer: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
153161
) -> pd.DataFrame:
154162
"""Extract structured data from documents using a generic pipeline.
155163
@@ -159,10 +167,9 @@ def extract_structured_data(
159167
prompt: Extraction prompt template (must include {context_str})
160168
id_column: Column name for document identifiers
161169
fields: List of field names to extract (if None, uses all model fields)
162-
preprocess: Whether to scale/preprocess images
163-
output_dir: Directory for preprocessed images
164-
scale_factor: Image scaling factor if preprocessing
165-
transform_fn: Optional transformation function for DataFrames
170+
image_transform_fn: Optional function to transform individual images (takes PIL Image, returns PIL Image)
171+
image_output_dir: Directory to save transformed images (required if image_transform_fn provided)
172+
data_transformer: Optional function to transform the extracted DataFrame
166173
167174
Returns:
168175
DataFrame with extracted data
@@ -173,22 +180,19 @@ def extract_structured_data(
173180
if fields is None:
174181
fields = list(output_cls.model_fields.keys())
175182

176-
# Preprocess images if requested
177-
if preprocess:
178-
if output_dir is None:
179-
raise ValueError("output_dir must be provided when preprocess=True")
180-
print("Preprocessing images...")
181-
paths_to_parse = [
182-
scale_image(Path(p), output_dir, scale_factor=scale_factor)
183-
for p in image_paths
184-
]
183+
# Process images if transformation function provided
184+
if image_transform_fn:
185+
if image_output_dir is None:
186+
raise ValueError("image_output_dir must be provided when image_transform_fn is specified")
187+
print("Processing images...")
188+
paths_to_parse = process_images(image_paths, image_output_dir, image_transform_fn)
185189
else:
186190
paths_to_parse = image_paths
187191

188192
# Extract documents
189193
structured_data = extract_documents(paths_to_parse, prompt, id_column, output_cls)
190194

191195
# Create extracted DataFrame
192-
extracted_df = create_extracted_df(structured_data, id_column, fields, transform_fn)
196+
extracted_df = create_extracted_df(structured_data, id_column, fields, data_transformer)
193197

194198
return extracted_df

llm/smart_data_extraction_llamaindex/extract_receipts_pipeline.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pandas as pd
1919
from document_extraction_pipeline import extract_structured_data
20+
from PIL import Image
2021
from pydantic import BaseModel, Field
2122

2223

@@ -46,9 +47,31 @@ class Receipt(BaseModel):
4647
items: List[ReceiptItem] = Field(default_factory=list)
4748

4849

49-
# Receipt-specific transformations
50-
def transform_receipt_columns(df: pd.DataFrame) -> pd.DataFrame:
51-
"""Apply receipt-specific transformations."""
50+
# Receipt-specific image transformation
51+
def scale_receipt_image(img: Image.Image, scale_factor: int = 3) -> Image.Image:
52+
"""Scale up a receipt image for better OCR.
53+
54+
Args:
55+
img: PIL Image object
56+
scale_factor: Factor to scale up the image (default: 3x)
57+
58+
Returns:
59+
Transformed PIL Image
60+
"""
61+
new_size = (img.width * scale_factor, img.height * scale_factor)
62+
return img.resize(new_size, Image.Resampling.LANCZOS)
63+
64+
65+
# Receipt-specific data transformations
66+
def transform_receipt_data(df: pd.DataFrame) -> pd.DataFrame:
67+
"""Transform extracted receipt data (normalize text, convert types).
68+
69+
Args:
70+
df: DataFrame with extracted receipt data
71+
72+
Returns:
73+
Transformed DataFrame
74+
"""
5275
df = df.copy()
5376
df["company"] = df["company"].str.upper()
5477
df["total"] = pd.to_numeric(df["total"], errors="coerce")
@@ -78,17 +101,16 @@ def transform_receipt_columns(df: pd.DataFrame) -> pd.DataFrame:
78101
num_receipts = 10
79102
receipt_paths = sorted(receipt_dir.glob("*.jpg"))[:num_receipts]
80103

81-
# Run the pipeline
104+
# Run the pipeline - pass transformation function directly
82105
result_df = extract_structured_data(
83106
image_paths=receipt_paths,
84107
output_cls=Receipt,
85108
prompt=RECEIPT_PROMPT,
86109
id_column="receipt_id",
87110
fields=["company", "total", "purchase_date"],
88-
preprocess=True,
89-
output_dir=adjusted_receipt_dir,
90-
scale_factor=3,
91-
transform_fn=transform_receipt_columns,
111+
image_transform_fn=scale_receipt_image,
112+
image_output_dir=adjusted_receipt_dir,
113+
data_transformer=transform_receipt_data,
92114
)
93115

94116
print("\nExtraction complete!")

0 commit comments

Comments (0)