Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions pyrefly/lib/alt/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,15 +571,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.solver().generalize_class_targs(cls.targs_mut());
}
let hint = None; // discard hint
let class_metadata = self.get_metadata_for_class(cls.class_object());
if let Some(ret) =
self.call_metaclass(&cls, arguments_range, args, keywords, errors, context, hint)
&& !self.is_compatible_constructor_return(&ret, cls.class_object())
{
if let Some(metaclass_dunder_call) = self.get_metaclass_dunder_call(&cls) {
if let Some(callee_range) = callee_range
&& let Some(metaclass) = self
.get_metadata_for_class(cls.class_object())
.custom_metaclass()
&& let Some(metaclass) = class_metadata.custom_metaclass()
{
self.record_external_attribute_definition_index(
&metaclass.clone().to_type(),
Expand Down Expand Up @@ -676,6 +675,17 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
self.record_resolved_trace(arguments_range, init_method);
}
if class_metadata.is_pydantic_base_model()
&& let Some(dataclass) = class_metadata.dataclass_metadata()
{
self.check_pydantic_argument_range_constraints(
cls.class_object(),
dataclass,
args,
keywords,
errors,
);
}
self.solver()
.finish_class_targs(cls.targs_mut(), self.uniques);
if let Some(mut ret) = dunder_new_ret {
Expand Down
2 changes: 1 addition & 1 deletion pyrefly/lib/alt/class/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
)
}

fn iter_fields(
pub(crate) fn iter_fields(
&self,
cls: &Class,
dataclass: &DataclassMetadata,
Expand Down
214 changes: 206 additions & 8 deletions pyrefly/lib/alt/class/pydantic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@ use pyrefly_types::literal::Lit;
use pyrefly_types::types::Union;
use ruff_python_ast::Expr;
use ruff_python_ast::name::Name;
use ruff_text_size::Ranged;
use ruff_text_size::TextRange;
use starlark_map::small_map::SmallMap;

use crate::alt::answers::LookupAnswer;
use crate::alt::answers_solver::AnswersSolver;
use crate::alt::callable::CallArg;
use crate::alt::callable::CallKeyword;
use crate::alt::solve::TypeFormContext;
use crate::alt::types::class_metadata::ClassMetadata;
use crate::alt::types::class_metadata::ClassSynthesizedField;
Expand All @@ -53,6 +57,52 @@ use crate::error::context::ErrorInfo;
use crate::types::class::Class;
use crate::types::types::Type;

fn int_literal_from_type(ty: &Type) -> Option<&LitInt> {
// We only currently enforce range constraints for literal ints.
match ty {
Type::Literal(Lit::Int(lit)) => Some(lit),
_ => None,
}
}

#[derive(Clone)]
struct PydanticRangeConstraints {
gt: Option<Type>,
ge: Option<Type>,
lt: Option<Type>,
le: Option<Type>,
}

impl PydanticRangeConstraints {
fn from_keywords(keywords: &DataclassFieldKeywords) -> Option<Self> {
if keywords.gt.is_none()
&& keywords.ge.is_none()
&& keywords.lt.is_none()
&& keywords.le.is_none()
{
return None;
}
Some(Self {
gt: keywords.gt.clone(),
ge: keywords.ge.clone(),
lt: keywords.lt.clone(),
le: keywords.le.clone(),
})
}
}

#[derive(Clone)]
struct PydanticParamConstraint {
field_name: Name,
constraints: PydanticRangeConstraints,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum PydanticParamKey {
Position(usize),
Name(Name),
}

impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
pub fn get_pydantic_root_model_type_via_mro(
&self,
Expand Down Expand Up @@ -414,14 +464,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
range: TextRange,
errors: &ErrorCollector,
) {
fn int_literal_from_type(ty: &Type) -> Option<&LitInt> {
// We only currently enforce range constraints for literal defaults, so carve out
// the `Literal[int]` case and ignore everything else.
match ty {
Type::Literal(Lit::Int(lit)) => Some(lit),
_ => None,
}
}
let Some(default_ty) = &keywords.default else {
return;
};
Expand Down Expand Up @@ -499,4 +541,160 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
None
}

pub fn check_pydantic_argument_range_constraints(
&self,
cls: &Class,
dataclass: &DataclassMetadata,
args: &[CallArg],
keywords: &[CallKeyword],
errors: &ErrorCollector,
) {
let Some(constraints) = self.collect_pydantic_constraint_params(cls, dataclass) else {
return;
};

let infer_errors = self.error_swallower();
for (index, arg) in args.iter().enumerate() {
match arg {
CallArg::Arg(value) => {
let value_ty = value.infer(self, &infer_errors);
if let Some(info) = constraints.get(&PydanticParamKey::Position(index)) {
self.emit_pydantic_argument_constraint(
&value_ty,
info,
arg.range(),
errors,
);
}
}
CallArg::Star(..) => {
// Can't reliably map starred arguments to parameters.
break;
}
}
}

for kw in keywords {
let Some(identifier) = kw.arg.as_ref() else {
continue;
};
let key = PydanticParamKey::Name(identifier.id.clone());
if let Some(info) = constraints.get(&key) {
let value_ty = kw.value.infer(self, &infer_errors);
self.emit_pydantic_argument_constraint(&value_ty, info, kw.range, errors);
}
}
}

fn collect_pydantic_constraint_params(
&self,
cls: &Class,
dataclass: &DataclassMetadata,
) -> Option<SmallMap<PydanticParamKey, PydanticParamConstraint>> {
let mut constraints = SmallMap::new();
let mut position = 0;

for (field_name, _field, keywords) in self.iter_fields(cls, dataclass, true) {
if !keywords.init {
continue;
}

let constraint = PydanticRangeConstraints::from_keywords(&keywords);

if let Some(info) = constraint.as_ref().map(|c| PydanticParamConstraint {
field_name: field_name.clone(),
constraints: c.clone(),
}) {
if keywords.init_by_name {
constraints.insert(PydanticParamKey::Name(field_name.clone()), info.clone());
}
if let Some(alias) = &keywords.init_by_alias {
constraints.insert(PydanticParamKey::Name(alias.clone()), info.clone());
}
}

if keywords.init_by_name && !keywords.is_kw_only() {
if let Some(info) = constraint.as_ref() {
constraints.insert(
PydanticParamKey::Position(position),
PydanticParamConstraint {
field_name: field_name.clone(),
constraints: info.clone(),
},
);
}
position += 1;
}

if let Some(_alias) = &keywords.init_by_alias
&& !keywords.is_kw_only()
{
if let Some(info) = constraint.as_ref() {
constraints.insert(
PydanticParamKey::Position(position),
PydanticParamConstraint {
field_name: field_name.clone(),
constraints: info.clone(),
},
);
}
position += 1;
}
}

if constraints.is_empty() {
None
} else {
Some(constraints)
}
}

fn emit_pydantic_argument_constraint(
&self,
value_ty: &Type,
info: &PydanticParamConstraint,
range: TextRange,
errors: &ErrorCollector,
) {
let Some(value_lit) = int_literal_from_type(value_ty) else {
return;
};
let checks = [
("gt", info.constraints.gt.as_ref()),
("ge", info.constraints.ge.as_ref()),
("lt", info.constraints.lt.as_ref()),
("le", info.constraints.le.as_ref()),
];
for (label, constraint_ty) in checks {
let Some(constraint_ty) = constraint_ty else {
continue;
};
let Some(constraint_lit) = int_literal_from_type(constraint_ty) else {
continue;
};
let comparison = value_lit.cmp(constraint_lit);
let violates = match label {
"gt" => !matches!(comparison, std::cmp::Ordering::Greater),
"ge" => matches!(comparison, std::cmp::Ordering::Less),
"lt" => !matches!(comparison, std::cmp::Ordering::Less),
"le" => matches!(comparison, std::cmp::Ordering::Greater),
_ => false,
};
if violates {
self.error(
errors,
range,
ErrorInfo::Kind(ErrorKind::BadArgumentType),
format!(
"Argument value `{}` violates Pydantic `{}` constraint `{}` for field `{}`",
self.for_display(value_ty.clone()),
label,
self.for_display(constraint_ty.clone()),
info.field_name
),
);
}
}
}
}
74 changes: 71 additions & 3 deletions pyrefly/lib/test/pydantic/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,84 @@ use crate::test::util::TestEnv;
use crate::testcase;

pydantic_testcase!(
bug = "we could support ranges, but this is not for v1",
test_field_right_type,
r#"
from pydantic import BaseModel, Field
class Model(BaseModel):
x: int = Field(gt=0, lt=10)

Model(x=5)
Model(x=0)
Model(x=15)
Model(x=0) # E: Argument value `Literal[0]` violates Pydantic `gt` constraint `Literal[0]` for field `x`
Model(x=15) # E: Argument value `Literal[15]` violates Pydantic `lt` constraint `Literal[10]` for field `x`
"#,
);

pydantic_testcase!(
test_field_range_ge_le,
r#"
from pydantic import BaseModel, Field

class Model(BaseModel):
x: int = Field(ge=0, le=10)

Model(x=0)
Model(x=10)
Model(x=-1) # E: Argument value `Literal[-1]` violates Pydantic `ge` constraint `Literal[0]` for field `x`
Model(x=11) # E: Argument value `Literal[11]` violates Pydantic `le` constraint `Literal[10]` for field `x`
"#,
);

pydantic_testcase!(
test_field_range_positional,
r#"
from pydantic import BaseModel, Field

class Model(BaseModel):
x: int = Field(gt=0, kw_only=False)
y: int = Field(lt=3, kw_only=False)

Model(1, 2)
Model(0, 2) # E: Argument value `Literal[0]` violates Pydantic `gt` constraint `Literal[0]` for field `x`
Model(1, 3) # E: Argument value `Literal[3]` violates Pydantic `lt` constraint `Literal[3]` for field `y`
"#,
);

pydantic_testcase!(
test_field_range_kw_only,
r#"
from pydantic import BaseModel, Field

class Model(BaseModel):
x: int = Field(ge=1, kw_only=True)

Model(x=1)
Model(x=0) # E: Argument value `Literal[0]` violates Pydantic `ge` constraint `Literal[1]` for field `x`
"#,
);

pydantic_testcase!(
test_field_range_alias,
r#"
from pydantic import BaseModel, Field

class Model(BaseModel, validate_by_name=True, validate_by_alias=True):
x: int = Field(gt=0, validation_alias="y")

Model(x=0) # E: Argument value `Literal[0]` violates Pydantic `gt` constraint `Literal[0]` for field `x`
Model(y=0) # E: Argument value `Literal[0]` violates Pydantic `gt` constraint `Literal[0]` for field `x`
"#,
);

pydantic_testcase!(
test_field_range_alias_only,
r#"
from pydantic import BaseModel, Field

class Model(BaseModel, validate_by_name=False, validate_by_alias=True):
x: int = Field(gt=0, validation_alias="y")

Model(y=0) # E: Argument value `Literal[0]` violates Pydantic `gt` constraint `Literal[0]` for field `x`
Model(x=0) # E: Missing argument `y`
"#,
);

Expand Down