[SPARK-57521][ML][CONNECT] Exclude parent from Model.estimatedSize to fix overcounting in ML cache#56584
[SPARK-57521][ML][CONNECT] Exclude parent from Model.estimatedSize to fix overcounting in ML cache#56584mkincaid wants to merge 2 commits into
Conversation
|
Fixed actions configuration on my fork. Closing and reopening to trigger the checks to rerun |
| // shared SparkSession state as part of every model's size. | ||
| // The parent is @transient (not persisted) and is not needed for transform() or save(). | ||
| val savedParent = parent | ||
| parent = null |
There was a problem hiding this comment.
Please investigate and address a possible thread-safety regression here. There is a side-effecting mutation of shared state in a base-class method. Two concurrent estimatedSize calls on the same model can interleave so both save then both restore, with the second finally clobbering parent to null permanently; a concurrent reader (hasParent, transform, save) can also observe parent == null during the window.
In the current path this is masked because estimatedSize is invoked inside MLCache.register (which is synchronized) on a freshly-fit, not-yet-shared model, so it is not an active production bug today, but estimatedSize is private[spark] and the previous implementation was side-effect free, so the new contract is strictly weaker.
Please consider a non-mutating approach rather than mutating shared instance state.
There was a problem hiding this comment.
Hi @uros-b, thanks for the quick review and input. I pushed a change that adds synchronized so that we wouldn't have two concurrent estimatedSize calls from here. However, I'm realizing this doesn't address the second part of your comment (a concurrent reader from elsewhere would still observe parent == null).
As I looked into this further, the truly non-mutating approaches I came up with were:
- Create a
copyof theModelwith emptyparent, then size that. But this depends on the implementation of thecopymethod which is model-specific (so not sure if it can be relied on to faithfully copy everything we care about sizing). - Make the
ModelobjectCloneable, thenclone(), clearparent, and size. But changing an interface ofModelitself seems less conservative and beyond the scope I was intending for this original fix. - Serialize and deserialize the object before estimating its size. Since the
parentis@transientit would be gone in the serialized copy. This seems conceptually appealing (it seems like, in principle, the data the model keeps and serializes is the state we care about sizing) but not sure if it might be expensive for large models and, like thecopyoption, should I worry about the possibility something relevant doesn’t survive the round trip. - Target the fix elsewhere, e.g., perhaps
SizeEstimatoritself should skip walking throughSparkSessionobjects (the same way as there are existing exclusions there forClassLoaderandscala.reflect). This also seems less conservative since other users ofSizeEstimatormight not want the behavior to change.
Or I may be missing something easier/cleaner. It is probably pretty obvious that I'm new to this code base, so I want to be thoughtful about design and get more input before proceeding. Appreciate your patience with me and looking forward to your thoughts :)
uros-b
left a comment
There was a problem hiding this comment.
Thank you @mkincaid. Left a few comments, passing on to @zhengruifeng and @WeichenXu123 (ML/Connect experts) for further review.
… fix overcounting in ML cache
2f6bb9e to
22471c8
Compare
What changes were proposed in this pull request?
This patch unsets
parentbefore calling theSizeEstimator.Why are the changes needed?
Currently
SizeEstimatorincludes the size of theSparkSessionbecause it traverses theparentobject which (in the case of many estimators that use DataFrame operations when fitting, likeStringIndexer) eventually refers to the session. The session is there anyway and its size isn't attributable to fitting this specific model (and this results in double-counting when more models are fit), so it shouldn't be included in the size estimate.The impact of the bug is largest when the
SparkSessionis large. For example, in Databricks, my testing shows that a 300-800M SparkSession is typical. In some configurations, like Databricks serverless, the size limit for a single model object might be 256M, so this bug causes such models to fail to train regardless of the state of the cache otherwise.The Jira ticket includes a simple script that reproduces the condition locally, though the session is much smaller in that case (maybe 300k).
Does this PR introduce any user-facing change?
Yes, a favorable one, in that the model cache would fill less quickly (and the reported sizes of cached models would be smaller, if they are among the affected models).
How was this patch tested?
A test is added: training a
StringIndexershould estimate at no larger than 50k, in the trivial test case with 3 strings. This test fails before the patch and passes after it. Another similar test is provided forMinMaxScaler. AModelSuiteis added to hold these since the bug is at theModellevel, not that of individual models (so theStringIndexerandMinMaxScalersuites aren't really the right place for these tests, although they are examples).Was this patch authored or co-authored using generative AI tooling?
Yes, the bug was discovered and initial patch/tests were created by pair programming with Claude. I wrote the bug/docs myself and validated the approach and final patch.
Generated-by: Claude Opus 4.6