Skip to content

fix Support functools.partial #3175#3178

Open
asukaminato0721 wants to merge 3 commits intofacebook:mainfrom
asukaminato0721:3175
Open

fix Support functools.partial #3175#3178
asukaminato0721 wants to merge 3 commits intofacebook:mainfrom
asukaminato0721:3175

Conversation

@asukaminato0721
Copy link
Copy Markdown
Contributor

Summary

Fixes #3175

Implemented functools.partial support for concrete callable signatures.

When Pyrefly sees functools.partial applied to a normal list-parameter callable, it now rewrites the result to the remaining callable signature and validates the bound arguments up front.

cases like partial(f, 1) behave like a one-argument callable and over-applied constructors like partial(f, 1, "a", 2, "b", 3, "c", 4, "d") report an arity error.

Test Plan

add test

@meta-cla meta-cla Bot added the cla signed label Apr 19, 2026
@asukaminato0721 asukaminato0721 marked this pull request as ready for review April 19, 2026 10:38
Copilot AI review requested due to automatic review settings April 19, 2026 10:38
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds concrete signature support for functools.partial so Pyrefly can (a) validate bound arguments at partial-construction time and (b) type the resulting value as a callable with the remaining (unbound) parameters.

Changes:

  • Extend constructor-call inference to detect functools.partial(...) and rewrite its resulting type to a Callable with the remaining parameter list.
  • Validate bound positional/keyword arguments against the target callable signature and emit arity/type errors early.
  • Add tests covering remaining-signature preservation and rejection of over-binding.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
pyrefly/lib/alt/call.rs Implements functools.partial special-casing in constructor calls: extracts a concrete callable signature, binds arguments, and returns a rewritten callable type.
pyrefly/lib/test/callable.rs Adds regression tests to ensure partial preserves remaining arity/types and errors on too many bound arguments.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread pyrefly/lib/alt/call.rs Outdated
self.error(
errors,
kw.range(),
ErrorInfo::Kind(ErrorKind::BadArgumentCount),
Copy link

Copilot AI Apr 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error kind for an unexpected keyword in partial(...) is reported as BadArgumentCount, but elsewhere call-site keyword issues use ErrorKind::UnexpectedKeyword (see pyrefly/lib/alt/callable.rs where the message "Unexpected keyword argument ..." is emitted with UnexpectedKeyword). Using BadArgumentCount here makes the diagnostic category/suppression code inconsistent. Switch this branch to emit ErrorKind::UnexpectedKeyword (and keep the message as-is).

Suggested change
ErrorInfo::Kind(ErrorKind::BadArgumentCount),
ErrorInfo::Kind(ErrorKind::UnexpectedKeyword),

Copilot uses AI. Check for mistakes.
@github-actions

This comment has been minimized.

@github-actions

This comment has been minimized.

@github-actions github-actions Bot added size/l and removed size/l labels Apr 19, 2026
@github-actions

This comment has been minimized.

@github-actions

This comment has been minimized.

@github-actions github-actions Bot added size/l and removed size/l labels Apr 19, 2026
@github-actions

This comment has been minimized.

@github-actions

This comment has been minimized.

@github-actions
Copy link
Copy Markdown

Diff from mypy_primer, showing the effect of this PR on open source code:

jax (https://github.com/google/jax)
+ ERROR jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py:2576:7-11: Argument `Array | MultiHeadMask` is not assignable to parameter `mask` with type `Array` [bad-argument-type]
+ ERROR jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py:2592:11-15: Argument `Array | MultiHeadMask` is not assignable to parameter `mask` with type `Array` [bad-argument-type]
+ ERROR jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py:2593:11-26: Argument `tuple[int | Unknown | None, int | Unknown | None]` is not assignable to parameter `block_shape` with type `tuple[int, int]` [bad-argument-type]
+ ERROR jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py:2602:9-13: Argument `Array | MultiHeadMask` is not assignable to parameter `mask` with type `Array` [bad-argument-type]
+ ERROR jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py:2603:9-26: Argument `tuple[int | Unknown | None, int | Unknown | None]` is not assignable to parameter `block_shape` with type `tuple[int, int]` [bad-argument-type]

scipy (https://github.com/scipy/scipy)
- ERROR scipy/fft/_duccfft/basic.py:35:1-13: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:37:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:65:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:67:1-15: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:99:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:101:1-15: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:153:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:155:1-15: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:181:1-15: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:183:1-16: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:222:1-15: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:224:1-16: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:249:1-22: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/basic.py:251:1-23: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:49:1-13: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:51:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:54:1-13: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:56:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:102:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:104:1-15: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:107:1-14: Object of class `partial` has no attribute `__name__` [missing-attribute]
- ERROR scipy/fft/_duccfft/realtransforms.py:109:1-15: Object of class `partial` has no attribute `__name__` [missing-attribute]

pytest-autoprofile (https://gitlab.com/TTsangSC/pytest-autoprofile)
+ ERROR tests/test_doctest.py:184:30-35: Argument `CheckPytestConfig | Unknown` is not assignable to parameter `*args` with type `str` [bad-argument-type]

hydra-zen (https://github.com/mit-ll-responsible-ai/hydra-zen)
-  INFO tests/annotations/declarations.py:160:16-73: revealed type: Just[partial[int]] [reveal-type]
+  INFO tests/annotations/declarations.py:160:16-73: revealed type: Just[(() -> int) & partial[int]] [reveal-type]
-  INFO tests/annotations/declarations.py:166:16-58: revealed type: partial[int] [reveal-type]
+  INFO tests/annotations/declarations.py:166:16-58: revealed type: (() -> int) & partial[int] [reveal-type]
- ERROR tests/annotations/declarations.py:471:23-33: `partial[int]` is not assignable to `Partial[int]` [bad-assignment]
+ ERROR tests/annotations/declarations.py:471:23-33: `(() -> int) & partial[int]` is not assignable to `Partial[int]` [bad-assignment]
- ERROR tests/annotations/declarations.py:472:24-41: `partial[bool]` is not assignable to `Partial[bool]` [bad-assignment]
+ ERROR tests/annotations/declarations.py:472:24-41: `((*, x: str = ...) -> bool) & partial[bool]` is not assignable to `Partial[bool]` [bad-assignment]

spack (https://github.com/spack/spack)
- ERROR lib/spack/spack/solver/reuse.py:38:31-38: Argument `partial[list[Spec] | None]` is not assignable to parameter `factory` with type `() -> list[Spec]` in function `spack.spec_filter.SpecFilter.__init__` [bad-argument-type]
+ ERROR lib/spack/spack/solver/reuse.py:38:31-38: Argument `((*, configuration: Unknown = ...) -> list[Spec] | None) & partial[list[Spec] | None]` is not assignable to parameter `factory` with type `() -> list[Spec]` in function `spack.spec_filter.SpecFilter.__init__` [bad-argument-type]

prefect (https://github.com/PrefectHQ/prefect)
- ERROR src/prefect/task_engine.py:639:9-35: Object of class `partial` has no attribute `log_on_run` [missing-attribute]
- ERROR src/prefect/task_engine.py:1263:9-35: Object of class `partial` has no attribute `log_on_run` [missing-attribute]

@github-actions
Copy link
Copy Markdown

Primer Diff Classification

✅ 4 improvement(s) | ➖ 2 neutral | 6 project(s) total | +11, -29 errors

4 improvement(s) across jax, scipy, pytest-autoprofile, prefect.

Project Verdict Changes Error Kinds Root Cause
jax ✅ Improvement +5 bad-argument-type pyrefly/lib/alt/call.rs
scipy ✅ Improvement -22 missing-attribute functools_partial_callable()
pytest-autoprofile ✅ Improvement +1 bad-argument-type functools_partial_callable()
hydra-zen ➖ Neutral +4, -4 bad-assignment, reveal-type
spack ➖ Neutral +1, -1 bad-argument-type
prefect ✅ Improvement -2 missing-attribute functools_partial_callable()
Detailed analysis

✅ Improvement (4)

jax (+5)

These are genuine type issues in the JAX codebase that pyrefly (and mypy/pyright) correctly identify:

  1. mask type mismatch (lines 2576, 2592, 2602): After the isinstance(mask, np.ndarray) check on line 2555 converts np.ndarray masks to MultiHeadMask, and after the isinstance(mask, jax.Array) check on line 2565 selects which function to assign to process_mask_fn, the variable mask has type jax.Array | MultiHeadMask. The process_mask_fn variable is assigned either process_dynamic_mask (which expects jax.Array for mask) or process_mask (which expects MultiHeadMask for mask). However, the type checker cannot correlate the conditional assignment of process_mask_fn with the narrowed type of mask — it sees process_mask_fn as a union of both function types, and mask as jax.Array | MultiHeadMask. When calling process_mask_fn(mask, ...), neither function signature accepts the full union type. The code works at runtime because the isinstance(mask, jax.Array) check ensures the correct function is paired with the correct mask type, but the type system can't prove this correlation.

  2. block_shape type mismatch (lines 2593, 2603): bq_dq and bkv_dq are typed as int | None (from BlockSizes fields block_q_dq and block_kv_dq), so (bq_dq, bkv_dq) is tuple[int | None, int | None]. The function expects tuple[int, int]. Similarly, bq_dkv and bkv_dkv (from block_q_dkv and block_kv_dkv) have the same issue on line 2603. At runtime, has_backward_blocks guarantees these are not None when this code path is reached, but the type checker can't narrow through the property check.

These are real type-level issues that pyrefly (and mypy/pyright) correctly identify. The PR's functools.partial support enabled pyrefly to see through the partial calls at lines 2628-2636 and validate the arguments, which is why these errors now appear.

Attribution: The PR implements functools.partial support in pyrefly/lib/alt/call.rs. The new functools_partial_callable method and bind_partial_callable method now rewrite the callable signature when functools.partial is used. This means that at lines 2628-2636, where partial(_make_splash_attention, ...) is used, pyrefly now understands the resulting callable's signature and can validate arguments passed to it. Previously, pyrefly likely treated partial(...) results as opaque. Now it can see through the partial calls and validate the arguments against the remaining parameters of the wrapped function. The process_mask_fn variable is assigned via a conditional (lines 2563-2567), and when pyrefly resolves the partial call, it can now check the argument types against the concrete function signatures of process_mask and process_dynamic_mask, which expect Array for mask and tuple[int, int] for block_shape.

scipy (-22)

All 22 removed errors were false positives where pyrefly incorrectly reported that functools.partial objects lack a __name__ attribute. While partial objects do NOT have __name__ as a declared attribute (it is not in typeshed, and accessing it without first setting it raises AttributeError at runtime), partial objects do have a __dict__ and support arbitrary attribute assignment. The code in question is setting __name__ (e.g., fft.__name__ = 'fft'), which is valid Python. The PR's improved functools.partial support now correctly allows such attribute assignments, removing these 22 false positives with no new errors introduced.
Attribution: The functools_partial_callable() method in pyrefly/lib/alt/call.rs creates an intersection type between the partial instance type and the rewritten callable signature. The key lines are: let partial_instance = result.clone(); result = self.heap.mk_intersect(vec![partial_instance, partial_callable.clone()], partial_callable);. By preserving the partial_instance in the intersection, the resulting type retains access to partial's own attributes (like __name__, args, keywords, func), which eliminates the false missing-attribute errors.

pytest-autoprofile (+1)

The error occurs at tests/test_doctest.py:184:30-35 where CheckPytestConfig | Unknown is passed where str is expected. The Unknown component in the type indicates the value comes from code without complete type annotations, meaning the type checker cannot fully resolve the type. Pyright also flags this location (co-reported), while mypy does not. The fact that pyright independently flags this same location provides some evidence this could be a real type issue, though mypy's silence suggests it may also be a matter of type checker strictness differences. The Unknown in the union type means pyrefly cannot prove the value is a str, which is technically correct type-checking behavior — if the type is CheckPytestConfig | Unknown, it is not guaranteed to be str. However, it's also possible that at runtime the value is always a str and the Unknown arises from incomplete type information rather than a genuine bug. The PR's functools.partial support may have changed how certain callable signatures are tracked, potentially affecting type inference in this area, though without seeing the specific code changes it's difficult to confirm the exact mechanism. This is likely a case where improved type inference is producing a more precise (and stricter) diagnostic rather than necessarily catching a previously hidden bug.
Attribution: The change in pyrefly/lib/alt/call.rs in the functools_partial_callable() method and the surrounding constructor logic. The PR adds functools.partial support by intercepting calls to functools.partial in the constructor path, extracting the callable signature, binding partial arguments, and producing a new callable type. The key change is in the constructor resolution where partial_callable is computed and then intersected with the partial instance type via self.heap.[mk_intersect()](https://github.com/facebook/pyrefly/blob/main/pyrefly/lib/alt/call.rs). This new partial support likely changed how functools.partial instances are typed, which could affect downstream type checking when partial objects are called.

prefect (-2)

The two removed errors flagged handle_rollback.log_on_run = False where handle_rollback is a functools.partial object. While partial in typeshed doesn't declare log_on_run, Python's partial objects support arbitrary attribute assignment at runtime via __dict__. The errors were technically correct from a strict typing perspective (and mypy/pyright would also flag them), but they were false positives in practice since the code works fine at runtime. The PR's new functools.partial support changes the inferred type to an intersection type, which incidentally suppresses these errors. The net effect is removing 2 false-positive-in-practice errors, which is a minor improvement.
Attribution: The changes in pyrefly/lib/alt/call.rs, specifically the functools_partial_callable() method and the intersection type creation in the class instantiation path (lines around 992-998), changed the inferred type of partial(self.handle_rollback) from a plain partial[None] to an intersection of partial[None] and a callable type. This intersection type apparently doesn't trigger the missing-attribute check for log_on_run, causing these errors to disappear.

➖ Neutral (2)

hydra-zen (+4, -4)

Same errors at same locations with same error kinds — message wording changed, no behavioral impact.

spack (+1, -1)

Same errors at same locations with same error kinds — message wording changed, no behavioral impact.


Was this helpful? React with 👍 or 👎

Classification by primer-classifier (2 heuristic, 4 LLM)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support functools.partial

3 participants