Skip to content

Commit 9d459cf

Browse files
bobrenjc93pobin6
authored andcommitted
Add trunc to z3 validator (pytorch#140886)
Fixes vision_maskrcnn benchmark when validation is turned on Pull Request resolved: pytorch#140886 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#140830, pytorch#140832, pytorch#140828
1 parent 1e2d36a commit 9d459cf

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torch/fx/experimental/validator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
191191
self.floor(number) < number, self.floor(number + 1), number
192192
) # type: ignore[return-value]
193193

194+
def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
195+
return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value]
196+
194197
def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
195198
return z3.If(a > b, a, b) # type: ignore[return-value]
196199

@@ -291,6 +294,7 @@ def wrapper(*args):
291294
# Math module.
292295
math.ceil: lift(ops.ceil),
293296
math.floor: lift(ops.floor),
297+
math.trunc: lift(ops.trunc),
294298
# Torch module.
295299
torch.sym_float: lift(ops.to_real),
296300
torch.sym_max: lift(ops.max),

0 commit comments

Comments
 (0)