From db3a92010d8fddfe8aea9dca4b3f97daae5796b2 Mon Sep 17 00:00:00 2001 From: yinjie Date: Sun, 14 Sep 2025 13:49:48 -0400 Subject: [PATCH 01/17] feat: basic parameter subclass --- traincheck/proxy_wrapper/subclass.py | 76 ++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 traincheck/proxy_wrapper/subclass.py diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py new file mode 100644 index 00000000..a7ea3539 --- /dev/null +++ b/traincheck/proxy_wrapper/subclass.py @@ -0,0 +1,76 @@ +import copy +import logging +import os +import threading +import time +import types +from typing import Dict + +import torch +from torch import nn + +import traincheck.config.config as general_config +import traincheck.proxy_wrapper.proxy_config as proxy_config # HACK: cannot directly import config variables as then they would be local variables +import traincheck.proxy_wrapper.proxy_methods as proxy_methods +from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars +from traincheck.utils import get_timestamp_ns, typename + +from .dumper import json_dumper as dumper +from .proxy_basics import unproxy_arg, unproxy_args_kwargs +from .proxy_handler import PROXY_SUPPORT_OBJ_TYPES + +# from .proxy_registry import get_global_registry +from .utils import print_debug + + +def _is_fake_like(t: torch.Tensor) -> bool: + try: + from torch._subclasses.fake_tensor import FakeTensor + + if isinstance(t, FakeTensor): + return True + except Exception: + pass + if getattr(t, "fake_mode", None) is not None: + return True + if getattr(t, "_is_fake", False): + return True + ty = type(t) + if "fake" in ( + getattr(ty, "__name__", "").lower() + getattr(ty, "__module__", "").lower() + ): + return True + return False + + +class ProxyParameter(torch.nn.Parameter): + def __new__(cls, data, var_name=""): + if isinstance(data, ProxyParameter): + return data + + return torch.Tensor._make_subclass(cls, data.detach(), data.requires_grad) + + def __init__(self, data, var_name=""): + self.__dict__["varname"] = var_name + print(f"init: {self.varname}") + super().__init__() + + def __setattr__(self, name, value): + print(f"paremeter: {self.varname}, name = {name}, value = {value}") + + return super().__setattr__(name, value) + + def __deepcopy__(self, memo): + data = self.data + return type(self)( + data.clone(memory_format=torch.preserve_format), + self.requires_grad, + var_name=self.varname, + ) + + +def proxy_parameter(module: nn.Module, parent_name: str = ""): + for name, t in list(module.named_parameters(recurse=False)): + module._parameters[name] = ProxyParameter(t, parent_name + "." + name) + for name, child in module.named_children(): + proxy_parameter(child, parent_name + "." 
+ name) From 226a16b29d741f1ca7a0f10fbc9eafd8d2cc34c8 Mon Sep 17 00:00:00 2001 From: yinjie Date: Sun, 14 Sep 2025 15:52:30 -0400 Subject: [PATCH 02/17] fix: avoid proxy or copy the unsuitable object --- traincheck/proxy_wrapper/subclass.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index a7ea3539..9a62b422 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -45,7 +45,7 @@ def _is_fake_like(t: torch.Tensor) -> bool: class ProxyParameter(torch.nn.Parameter): def __new__(cls, data, var_name=""): - if isinstance(data, ProxyParameter): + if not isinstance(data, torch.nn.Parameter): return data return torch.Tensor._make_subclass(cls, data.detach(), data.requires_grad) @@ -62,6 +62,11 @@ def __setattr__(self, name, value): def __deepcopy__(self, memo): data = self.data + if not isinstance(data, ProxyParameter): + return torch.nn.Parameter( + data.clone(memory_format=torch.preserve_format), + self.requires_grad, + ) return type(self)( data.clone(memory_format=torch.preserve_format), self.requires_grad, From 7fd27d836742567f01df398ac1f08cf71176ecf2 Mon Sep 17 00:00:00 2001 From: yinjie Date: Sun, 14 Sep 2025 16:09:54 -0400 Subject: [PATCH 03/17] fix: deepcopy & avoid proxy during dynamo --- traincheck/proxy_wrapper/subclass.py | 44 ++++++++++++++++------------ 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index 9a62b422..6d9775d1 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -23,29 +23,39 @@ from .utils import print_debug -def _is_fake_like(t: torch.Tensor) -> bool: +def in_dynamo() -> bool: try: - from torch._subclasses.fake_tensor import FakeTensor + import torch._dynamo as dynamo - if isinstance(t, FakeTensor): + return bool(dynamo.is_compiling()) + except Exception: + return False + + +def is_fake_tensor(x: torch.Tensor) -> bool: + try: + from torch._subclasses.fake_tensor import FakeTensor # 2.x + + if isinstance(x, FakeTensor): return True except Exception: pass - if getattr(t, "fake_mode", None) is not None: - return True - if getattr(t, "_is_fake", False): + if getattr(x, "fake_mode", None) is not None: return True - ty = type(t) - if "fake" in ( - getattr(ty, "__name__", "").lower() + getattr(ty, "__module__", "").lower() - ): + if getattr(x, "_is_fake", False): return True - return False + + return isinstance(x, torch.Tensor) and x.device.type == "meta" class ProxyParameter(torch.nn.Parameter): def __new__(cls, data, var_name=""): - if not isinstance(data, torch.nn.Parameter): + if in_dynamo() or is_fake_tensor(data): + if isinstance(data, nn.Parameter): + return data + return nn.Parameter(data, requires_grad=data.requires_grad) + + if isinstance(data, ProxyParameter): return data return torch.Tensor._make_subclass(cls, data.detach(), data.requires_grad) @@ -62,19 +72,17 @@ def __setattr__(self, name, value): def __deepcopy__(self, memo): data = self.data - if not isinstance(data, ProxyParameter): - return torch.nn.Parameter( - data.clone(memory_format=torch.preserve_format), - self.requires_grad, - ) + if in_dynamo() or is_fake_tensor(self): + return self return type(self)( data.clone(memory_format=torch.preserve_format), - self.requires_grad, var_name=self.varname, ) def proxy_parameter(module: nn.Module, parent_name: str = ""): + if in_dynamo(): + return for name, t in 
list(module.named_parameters(recurse=False)): module._parameters[name] = ProxyParameter(t, parent_name + "." + name) for name, child in module.named_children(): From 3238aa8ea9131465d7fed86b76c3b222c93801f2 Mon Sep 17 00:00:00 2001 From: yinjie Date: Sun, 14 Sep 2025 17:07:58 -0400 Subject: [PATCH 04/17] feat: print the observation trace --- traincheck/proxy_wrapper/proxy_observer.py | 14 +++++++++++--- traincheck/proxy_wrapper/subclass.py | 3 +++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/traincheck/proxy_wrapper/proxy_observer.py b/traincheck/proxy_wrapper/proxy_observer.py index 06afcc2b..5dd38867 100644 --- a/traincheck/proxy_wrapper/proxy_observer.py +++ b/traincheck/proxy_wrapper/proxy_observer.py @@ -2,18 +2,26 @@ import typing from traincheck.config.config import should_disable_proxy_dumping +from traincheck.proxy_wrapper.subclass import ProxyParameter from traincheck.utils import typename if typing.TYPE_CHECKING: from traincheck.proxy_wrapper.proxy import Proxy + from traincheck.proxy_wrapper.subclass import ProxyParameter + from .proxy_basics import is_proxied, unproxy_func def observe_proxy_var( - var: "Proxy", + var: typing.Union["Proxy", "ProxyParameter"], phase, observe_api_name: str, ): + # TODO: After fully implement the ProxyParameter, don't need this check + if isinstance(var, ProxyParameter): + var.dump_trace(phase=phase, dump_loc=observe_api_name) + return + # update the proxy object's timestamp var.update_timestamp() @@ -37,9 +45,9 @@ def wrapper(*args, **kwargs): # if the arg is list or tuple, check if it contains proxied object if type(arg) in [list, tuple]: for element in arg: - if is_proxied(element): + if is_proxied(element) or isinstance(element, ProxyParameter): proxied_vars.append(element) - if is_proxied(arg): + if is_proxied(arg) or isinstance(arg, ProxyParameter): proxied_vars.append(arg) # pre observe diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index 6d9775d1..eaf222d5 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -79,6 +79,9 @@ def __deepcopy__(self, memo): var_name=self.varname, ) + def dump_trace(self, phase, dump_loc): + print(f"parameter: {self.varname}, phase = {phase}, dump_loc = {dump_loc}") + def proxy_parameter(module: nn.Module, parent_name: str = ""): if in_dynamo(): From 707b2ea321452041cedd238992a3751a26526566 Mon Sep 17 00:00:00 2001 From: yinjie Date: Sun, 14 Sep 2025 22:04:33 -0400 Subject: [PATCH 05/17] feat: dump trace --- traincheck/proxy_wrapper/proxy_observer.py | 4 - traincheck/proxy_wrapper/subclass.py | 152 +++++++++++++++++++-- 2 files changed, 141 insertions(+), 15 deletions(-) diff --git a/traincheck/proxy_wrapper/proxy_observer.py b/traincheck/proxy_wrapper/proxy_observer.py index 5dd38867..68e6b834 100644 --- a/traincheck/proxy_wrapper/proxy_observer.py +++ b/traincheck/proxy_wrapper/proxy_observer.py @@ -17,10 +17,6 @@ def observe_proxy_var( phase, observe_api_name: str, ): - # TODO: After fully implement the ProxyParameter, don't need this check - if isinstance(var, ProxyParameter): - var.dump_trace(phase=phase, dump_loc=observe_api_name) - return # update the proxy object's timestamp var.update_timestamp() diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index eaf222d5..ed690989 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -49,7 +49,25 @@ def is_fake_tensor(x: torch.Tensor) -> bool: class 
ProxyParameter(torch.nn.Parameter): - def __new__(cls, data, var_name=""): + loglevel = logging.INFO + jsondumper = dumper( + os.path.join(os.getenv("ML_DAIKON_OUTPUT_DIR", "."), "proxy_log.json") # type: ignore + ) + + def __new__( + cls, + data, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + var_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, + ): if in_dynamo() or is_fake_tensor(data): if isinstance(data, nn.Parameter): return data @@ -60,14 +78,58 @@ def __new__(cls, data, var_name=""): return torch.Tensor._make_subclass(cls, data.detach(), data.requires_grad) - def __init__(self, data, var_name=""): - self.__dict__["varname"] = var_name - print(f"init: {self.varname}") + def __init__( + self, + data, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + var_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, + ): super().__init__() + # Access proxy attribute: since we are wrapping the getattr method, we need to access the attribute directly + self.__dict__["process_id"] = os.getpid() + self.__dict__["thread_id"] = threading.current_thread().ident + self.__dict__["logdir"] = logdir + self.__dict__["log_level"] = log_level + # TODO + # self.__dict__["meta_vars"] = {} + # self.__dict__["is_traincheck_proxied_obj"] = True + # TODO + # self.__dict__["recurse"] = recurse + self.__dict__["var_name"] = var_name + # TODO + # self.__dict__["old_value"] = None + # self.__dict__["old_meta_vars"] = None + + current_time = get_timestamp_ns() + + self.__dict__["last_update_timestamp"] = current_time + + print(f"init: {self.var_name}") + if should_dump_trace: + if from_call: + phase = "call" + + if from_iter: + phase = "iter" + # if the object is generated from getattr, then do not dump it + else: + phase = "update" + self.dump_trace(phase=phase, dump_loc="initing") def __setattr__(self, name, value): - print(f"paremeter: {self.varname}, name = {name}, value = {value}") - + print(f"paremeter: {self.var_name}, name = {name}, value = {value}") + self.dump_trace( + phase="update", + dump_loc=f"__setattr__ (attribute '{name}')", + ) return super().__setattr__(name, value) def __deepcopy__(self, memo): @@ -76,17 +138,85 @@ def __deepcopy__(self, memo): return self return type(self)( data.clone(memory_format=torch.preserve_format), - var_name=self.varname, + var_name=self.var_name, ) + def update_timestamp(self): + # Update the timestamp of the object, should be called when the object is updated, e.g. 
__setattr__ and observer + current_time = get_timestamp_ns() + self.__dict__["last_update_timestamp"] = current_time + # TODO: + # Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time + + def register_object(self): + # get_global_registry().add_var(self, self.__dict__["var_name"]) + # TODO: implement the registry, we will need to make sure the registerred timestamp is updated and is consistent with the timestamp in the object + pass + def dump_trace(self, phase, dump_loc): - print(f"parameter: {self.varname}, phase = {phase}, dump_loc = {dump_loc}") + print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") + # TODO + var_name = self.__dict__["var_name"] + # assert var_name is not None # '' is allowed as a var_name (root object) + # filter_by_tensor_version = proxy_config.dump_info_config[ + # "filter_by_tensor_version" + # ] + # if filter_by_tensor_version and phase == "update": + # if hasattr(obj, "_version"): + # if obj._version == Proxy.var_dict[self.__dict__["var_name"]].version: + # return + + last_update_timestamp = self.__dict__["last_update_timestamp"] + + # TODO + # if not isinstance(obj, torch.nn.Module): + self.jsondumper.dump_json( + process_id=self.process_id, + thread_id=self.thread_id, + time=last_update_timestamp, + meta_vars=get_meta_vars(self), + var_name=var_name, + # TODO + var_type="torch.nn.Parameter", + change_type=phase, + # TODO: verify dump_attributes + var_attributes=dump_attributes(self, self.data), + dump_loc=dump_loc, + ) -def proxy_parameter(module: nn.Module, parent_name: str = ""): +def proxy_parameter( + module: nn.Module, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + parent_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, +): if in_dynamo(): return for name, t in list(module.named_parameters(recurse=False)): - module._parameters[name] = ProxyParameter(t, parent_name + "." + name) + module._parameters[name] = ProxyParameter( + t, + logdir, + log_level, + parent_name + "." + name, + should_dump_trace, + from_call, + from_iter, + ) for name, child in module.named_children(): - proxy_parameter(child, parent_name + "." + name) + proxy_parameter( + child, + logdir, + log_level, + parent_name + "." 
+ name, + should_dump_trace, + from_call, + from_iter, + ) From 63dc25f45b581e2240a062ce3ab3e45b201c0469 Mon Sep 17 00:00:00 2001 From: yinjie Date: Mon, 15 Sep 2025 11:29:04 -0400 Subject: [PATCH 06/17] fix: information collected --- traincheck/instrumentor/dumper.py | 4 ++++ traincheck/instrumentor/tracer.py | 19 ++++++++++++++----- traincheck/invariant/symbolic_value.py | 6 +++--- traincheck/proxy_wrapper/proxy_basics.py | 9 +++++++++ traincheck/proxy_wrapper/proxy_config.py | 11 +++++++++++ traincheck/proxy_wrapper/subclass.py | 11 ++++++----- 6 files changed, 47 insertions(+), 13 deletions(-) diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index f6bf03fc..f8472f4f 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -21,6 +21,7 @@ from traincheck.proxy_wrapper.proxy_config import ( attribute_black_list, primitive_types, + proxy_attribute, tensor_dump_format, ) from traincheck.utils import get_timestamp_ns, typename @@ -335,6 +336,9 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict ): continue + if attr_name in proxy_attribute: + continue + if attr_name in attribute_black_list: continue diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index cf28785a..53d06747 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -29,7 +29,11 @@ funcs_to_be_replaced, is_funcs_to_be_unproxied, ) -from traincheck.proxy_wrapper.proxy_basics import is_proxied, unproxy_func +from traincheck.proxy_wrapper.proxy_basics import ( + is_proxied, + is_proxyparamtetr, + unproxy_func, +) from traincheck.proxy_wrapper.proxy_config import enable_C_level_observer from traincheck.proxy_wrapper.proxy_registry import get_global_registry from traincheck.utils import get_timestamp_ns, get_unique_id, typename @@ -215,7 +219,7 @@ def global_wrapper( def find_proxy_in_args(args): for i, arg in enumerate(args): - if is_proxied(arg): + if is_proxied(arg) or is_proxyparamtetr(arg): proxy_in_args.append(arg) elif type(arg) in [list, tuple]: find_proxy_in_args(arg) @@ -234,9 +238,14 @@ def find_proxy_in_args(args): if "proxy_obj_names" not in pre_record: pre_record["proxy_obj_names"] = [] for proxy in proxy_in_args: - pre_record["proxy_obj_names"].append( - [proxy.__dict__["var_name"], type(proxy._obj).__name__] - ) + if is_proxyparamtetr(proxy): + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], "Parameter"] + ) + else: + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], type(proxy._obj).__name__] + ) if dump_args: dict_args_kwargs = to_dict_args_kwargs(args, kwargs, dump_args_config) pre_record["args"] = dict_args_kwargs["args"] diff --git a/traincheck/invariant/symbolic_value.py b/traincheck/invariant/symbolic_value.py index 2cdbb174..d3a522d4 100644 --- a/traincheck/invariant/symbolic_value.py +++ b/traincheck/invariant/symbolic_value.py @@ -117,9 +117,9 @@ def generalize_values(values: list[type]) -> MD_NONE | type | str: min_value = min(all_non_none_values) # type: ignore max_value = max(all_non_none_values) # type: ignore - assert ( - min_value != max_value - ), "Min and max values are the same, you don't need to generalize the values" + # assert ( + # min_value != max_value + # ), "Min and max values are the same, you don't need to generalize the values" if min_value > 0: return ABOVE_ZERO elif min_value >= 0: diff --git a/traincheck/proxy_wrapper/proxy_basics.py b/traincheck/proxy_wrapper/proxy_basics.py index 
11f8162b..8fb3dd01 100644 --- a/traincheck/proxy_wrapper/proxy_basics.py +++ b/traincheck/proxy_wrapper/proxy_basics.py @@ -14,6 +14,15 @@ def is_proxied(obj): return False +def is_proxyparamtetr(obj): + try: + if obj is not None and "is_traincheck_proxyparameter" in obj.__dict__: + return True + except Exception: + return False + return False + + def unproxy_arg(arg, inspect_torch_module=False): if is_proxied(arg): diff --git a/traincheck/proxy_wrapper/proxy_config.py b/traincheck/proxy_wrapper/proxy_config.py index 57c2d4d1..66ce6d7c 100644 --- a/traincheck/proxy_wrapper/proxy_config.py +++ b/traincheck/proxy_wrapper/proxy_config.py @@ -49,3 +49,14 @@ "real", ] attribute_black_list = tensor_attribute_black_list +# TODO +proxy_attribute = [ + "process_id", + "thread_id", + "logdir", + "log_level", + "loglevel", + "is_traincheck_proxyparameter", + "var_name", + "last_update_timestamp", +] diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index ed690989..7914e8e3 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -101,6 +101,7 @@ def __init__( # TODO # self.__dict__["meta_vars"] = {} # self.__dict__["is_traincheck_proxied_obj"] = True + self.__dict__["is_traincheck_proxyparameter"] = True # TODO # self.__dict__["recurse"] = recurse self.__dict__["var_name"] = var_name @@ -112,7 +113,7 @@ def __init__( self.__dict__["last_update_timestamp"] = current_time - print(f"init: {self.var_name}") + # print(f"init: {self.var_name}") if should_dump_trace: if from_call: phase = "call" @@ -125,12 +126,12 @@ def __init__( self.dump_trace(phase=phase, dump_loc="initing") def __setattr__(self, name, value): - print(f"paremeter: {self.var_name}, name = {name}, value = {value}") + # print(f"paremeter: {self.var_name}, name = {name}, value = {value}") + super().__setattr__(name, value) self.dump_trace( phase="update", dump_loc=f"__setattr__ (attribute '{name}')", ) - return super().__setattr__(name, value) def __deepcopy__(self, memo): data = self.data @@ -154,7 +155,7 @@ def register_object(self): pass def dump_trace(self, phase, dump_loc): - print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") + # print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") # TODO var_name = self.__dict__["var_name"] # assert var_name is not None # '' is allowed as a var_name (root object) @@ -180,7 +181,7 @@ def dump_trace(self, phase, dump_loc): var_type="torch.nn.Parameter", change_type=phase, # TODO: verify dump_attributes - var_attributes=dump_attributes(self, self.data), + var_attributes=dump_attributes(self, self), dump_loc=dump_loc, ) From 90b8eaa02ce30367d279a6311dc5719fe297e569 Mon Sep 17 00:00:00 2001 From: yinjie Date: Mon, 15 Sep 2025 22:15:14 -0400 Subject: [PATCH 07/17] fix: update time when setattr --- traincheck/invariant/symbolic_value.py | 6 +++--- traincheck/proxy_wrapper/subclass.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/traincheck/invariant/symbolic_value.py b/traincheck/invariant/symbolic_value.py index d3a522d4..2cdbb174 100644 --- a/traincheck/invariant/symbolic_value.py +++ b/traincheck/invariant/symbolic_value.py @@ -117,9 +117,9 @@ def generalize_values(values: list[type]) -> MD_NONE | type | str: min_value = min(all_non_none_values) # type: ignore max_value = max(all_non_none_values) # type: ignore - # assert ( - # min_value != max_value - # ), "Min and max values are the same, you don't need to generalize the values" + assert ( + min_value != 
max_value + ), "Min and max values are the same, you don't need to generalize the values" if min_value > 0: return ABOVE_ZERO elif min_value >= 0: diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index 7914e8e3..70e71cd3 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -128,6 +128,7 @@ def __init__( def __setattr__(self, name, value): # print(f"paremeter: {self.var_name}, name = {name}, value = {value}") super().__setattr__(name, value) + self.update_timestamp() self.dump_trace( phase="update", dump_loc=f"__setattr__ (attribute '{name}')", From 56f17e36f456377e35e6122e0737bc75591244c4 Mon Sep 17 00:00:00 2001 From: yinjie Date: Tue, 16 Sep 2025 11:15:06 -0400 Subject: [PATCH 08/17] feat: instrument proxyparameter --- traincheck/collect_trace.py | 2 +- traincheck/instrumentor/source_file.py | 21 +++++++++++++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index 48bcfe87..c250f3a9 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -350,7 +350,7 @@ def main(): parser.add_argument( "--model-tracker-style", type=str, - choices=["sampler", "proxy"], + choices=["sampler", "proxy", "proxyparameter"], default="proxy", ) parser.add_argument( diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 4de57416..62ca16bd 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -271,7 +271,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]: def instrument_all_model_assignments( - source_code: str, model_name: str, mode: str + source_code: str, model_name: str, mode: str | None ) -> str: """ Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement @@ -292,6 +292,11 @@ def instrument_all_model_assignments( instr_statement = ast.parse( f"{model_name}_sampler = VarSampler({model_name}, var_name='{model_name}')" ) + elif mode == "proxyparameter": + instr_statement = ast.parse( + f"proxy_parameter({model_name}, logdir=proxy_config.proxy_log_dir, parent_name='{model_name}')" + ) + else: raise ValueError(f"Invalid mode: {mode}. 
Must be one of ['proxy', 'sampler']") @@ -348,6 +353,7 @@ def instrument_model_tracker_proxy( models_to_track: list[str], adjusted_proxy_config: list[dict[str, int | bool | str]], no_auto_var_instr: bool, + model_tracker_style: str | None, ): auto_observer_config: dict[str, int | bool | str] = adjusted_proxy_config[0] proxy_basic_config: dict[str, int | bool | str] = adjusted_proxy_config[1] @@ -373,8 +379,13 @@ def instrument_model_tracker_proxy( tensor_dump_format.update({tensor_dump_format}) """ - proxy_start_code += """ + if model_tracker_style == "proxy": + proxy_start_code += """ from traincheck.proxy_wrapper.proxy import Proxy +""" + else: + proxy_start_code += """ +from traincheck.proxy_wrapper.subclass import proxy_parameter """ if auto_observer_config["enable_auto_observer"]: @@ -435,7 +446,7 @@ def instrument_model_tracker_proxy( if not no_auto_var_instr: for model in models_to_track: instrumented_source = instrument_all_model_assignments( - instrumented_source, model, "proxy" + instrumented_source, model, model_tracker_style ) code_head, code_tail = get_code_head_and_tail(instrumented_source) @@ -840,13 +851,15 @@ def instrument_file( assert model_tracker_style in [ "proxy", "sampler", + "proxyparameter", ], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler']" - if model_tracker_style == "proxy": + if model_tracker_style == "proxy" or model_tracker_style == "proxyparameter": instrumented_source = instrument_model_tracker_proxy( instrumented_source, models_to_track, adjusted_proxy_config, no_auto_var_instr, + model_tracker_style, ) else: instrumented_source = instrument_model_tracker_sampler( From 77ebebbb23005f2394c258062206f9a521a3fd07 Mon Sep 17 00:00:00 2001 From: yinjie Date: Tue, 16 Sep 2025 11:46:01 -0400 Subject: [PATCH 09/17] fix: scan_proxy_in_args --- traincheck/collect_trace.py | 2 +- traincheck/proxy_wrapper/subclass.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index c250f3a9..69c3d7c0 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -444,7 +444,7 @@ def main(): scan_proxy_in_args = not args.disable_scan_proxy_in_args # if no proxy tracking specified in the arguments, disable the scan_proxy_in_args - if not args.models_to_track or args.model_tracker_style != "proxy": + if not args.models_to_track or args.model_tracker_style == "sampler": scan_proxy_in_args = False if args.invariants: diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index 70e71cd3..205fbac0 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -72,7 +72,7 @@ def __new__( if isinstance(data, nn.Parameter): return data return nn.Parameter(data, requires_grad=data.requires_grad) - + # TODO: verify if isinstance(data, ProxyParameter): return data From 88d2ab7b67671a0f900e36f6700945731741d84f Mon Sep 17 00:00:00 2001 From: yinjie Date: Wed, 17 Sep 2025 10:58:43 -0400 Subject: [PATCH 10/17] fix: remove unused imports --- traincheck/proxy_wrapper/subclass.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index 205fbac0..5633fd03 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -1,26 +1,17 @@ -import copy import logging import os import threading -import time -import types -from typing import Dict import torch from torch 
import nn -import traincheck.config.config as general_config -import traincheck.proxy_wrapper.proxy_config as proxy_config # HACK: cannot directly import config variables as then they would be local variables -import traincheck.proxy_wrapper.proxy_methods as proxy_methods from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars -from traincheck.utils import get_timestamp_ns, typename +from traincheck.utils import get_timestamp_ns from .dumper import json_dumper as dumper -from .proxy_basics import unproxy_arg, unproxy_args_kwargs -from .proxy_handler import PROXY_SUPPORT_OBJ_TYPES # from .proxy_registry import get_global_registry -from .utils import print_debug +# from .utils import print_debug def in_dynamo() -> bool: From e67a5683e1f33d0763068d2b82fe5b38ad3ef2f7 Mon Sep 17 00:00:00 2001 From: yinjie Date: Sat, 20 Sep 2025 17:53:13 -0400 Subject: [PATCH 11/17] fix: torch.compile compatibility --- traincheck/config/config.py | 9 +++ traincheck/instrumentor/dumper.py | 81 +++++++++++++++++++--- traincheck/instrumentor/tracer.py | 6 +- traincheck/proxy_wrapper/proxy_basics.py | 42 ++++++++++- traincheck/proxy_wrapper/proxy_observer.py | 6 +- traincheck/proxy_wrapper/subclass.py | 21 +----- traincheck/utils.py | 8 +++ 7 files changed, 139 insertions(+), 34 deletions(-) diff --git a/traincheck/config/config.py b/traincheck/config/config.py index 51c457af..6b284d99 100644 --- a/traincheck/config/config.py +++ b/traincheck/config/config.py @@ -249,3 +249,12 @@ def should_disable_proxy_dumping() -> bool: "preprocessing", "postprocessing", } + +COMPILE_INTERNAL_MODULE = ( + "torch.fx", + # "torch._dynamo", + "torch._inductor", + "torch._subclasses", + "torch._higher_order_ops", + "torch.utils._sympy", +) diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index f8472f4f..04935e8a 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -18,13 +18,14 @@ # if torch.cuda.is_available(): from traincheck.proxy_wrapper.hash import tensor_hash +from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor from traincheck.proxy_wrapper.proxy_config import ( attribute_black_list, primitive_types, proxy_attribute, tensor_dump_format, ) -from traincheck.utils import get_timestamp_ns, typename +from traincheck.utils import get_timestamp_ns, typename, typename_compile DEBUG = os.environ.get("ML_DAIKON_DEBUG", False) THREAD_DATA = threading.local() @@ -45,12 +46,48 @@ logger = logging.getLogger(__name__) +def _json_default(o): + try: + if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"): + return str(o) + + if isinstance(o, torch.device): + return str(o) + if isinstance(o, torch.dtype): + return str(o) + if isinstance(o, torch.Size): + out = [] + for d in o: + try: + out.append(int(d)) + except Exception: + out.append(str(d)) + return out + except Exception: + pass + + if isinstance(o, set): + return list(o) + if isinstance(o, tuple): + return list(o) + + try: + import numpy as np + + if isinstance(o, (np.generic,)): + return o.item() + except Exception: + pass + + return repr(o) + + def serialize(obj_dict: dict[str, object | str]) -> str: try: - return orjson.dumps(obj_dict).decode("utf-8") + return orjson.dumps(obj_dict, default=_json_default).decode("utf-8") except Exception: # if orjson fails (e.g. 
cannot handle ints larger than 64-bit), fallback to json - return json.dumps(obj_dict) + return json.dumps(obj_dict, default=_json_default) def monitor_main_thread(main_thread, stop_event): @@ -350,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict attr = safe_getattr(var, attr_name) if attr is NOT_FOUND: - logger.warning( - f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." - ) - if var_type not in skip_attrs_due_to_errs: - skip_attrs_due_to_errs[var_type] = set() - skip_attrs_due_to_errs[var_type].add(attr_name) + if not ( + attr_name == "data" + and isinstance(var, torch.Tensor) + and not include_tensor_data + ): + logger.warning( + f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." + ) + if var_type not in skip_attrs_due_to_errs: + skip_attrs_due_to_errs[var_type] = set() + skip_attrs_due_to_errs[var_type].add(attr_name) continue attr_name = str(attr_name) @@ -399,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict return result +def convert_fake_tensor_to_dict(var): + try: + shape = tuple(var.shape) + except Exception: + shape = None + try: + dtype = str(var.dtype) + except Exception: + dtype = None + return { + "fake": True, + "shape": shape, + "dtype": dtype, + } + + def obj_to_serializable(obj, dump_config=None) -> dict[str, object]: + if is_fake_tensor(obj): + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} if ( type(obj) in skip_type_due_to_recursion and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD @@ -433,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]: If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`. 
""" + if is_fake_tensor(obj): + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} + if issubclass(type(obj), dict) and type(obj) != dict: # noqa E721 return obj_to_serializable(obj, dump_config=dump_config) diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index 53d06747..a4812b7d 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -31,7 +31,7 @@ ) from traincheck.proxy_wrapper.proxy_basics import ( is_proxied, - is_proxyparamtetr, + is_proxyparameter, unproxy_func, ) from traincheck.proxy_wrapper.proxy_config import enable_C_level_observer @@ -219,7 +219,7 @@ def global_wrapper( def find_proxy_in_args(args): for i, arg in enumerate(args): - if is_proxied(arg) or is_proxyparamtetr(arg): + if is_proxied(arg) or is_proxyparameter(arg): proxy_in_args.append(arg) elif type(arg) in [list, tuple]: find_proxy_in_args(arg) @@ -238,7 +238,7 @@ def find_proxy_in_args(args): if "proxy_obj_names" not in pre_record: pre_record["proxy_obj_names"] = [] for proxy in proxy_in_args: - if is_proxyparamtetr(proxy): + if is_proxyparameter(proxy): pre_record["proxy_obj_names"].append( [proxy.__dict__["var_name"], "Parameter"] ) diff --git a/traincheck/proxy_wrapper/proxy_basics.py b/traincheck/proxy_wrapper/proxy_basics.py index 8fb3dd01..071c6f4e 100644 --- a/traincheck/proxy_wrapper/proxy_basics.py +++ b/traincheck/proxy_wrapper/proxy_basics.py @@ -4,9 +4,47 @@ import astor +from traincheck.config.config import COMPILE_INTERNAL_MODULE + + +def is_compile_internal_module(obj): + mod = getattr(type(obj), "__module__", "") or "" + if any(mod.startswith(p) for p in COMPILE_INTERNAL_MODULE): + return True + name = type(obj).__name__ + if mod.startswith("torch._dynamo") and name != "OptimizedModule": + return True + return False + + +def is_fake_tensor(x) -> bool: + try: + from torch._subclasses.fake_tensor import FakeTensor + from torch.fx import Proxy as FxProxy + + if isinstance(x, FakeTensor): + return True + if isinstance(x, FxProxy): + return True + except Exception: + pass + + try: + if is_compile_internal_module(x): + return True + except Exception: + return True + + try: + return x.device.type == "meta" + except Exception: + return True + def is_proxied(obj): try: + if is_fake_tensor(obj): + return False if obj is not None and "is_traincheck_proxied_obj" in obj.__dict__: return True except Exception: @@ -14,8 +52,10 @@ def is_proxied(obj): return False -def is_proxyparamtetr(obj): +def is_proxyparameter(obj): try: + if is_fake_tensor(obj): + return False if obj is not None and "is_traincheck_proxyparameter" in obj.__dict__: return True except Exception: diff --git a/traincheck/proxy_wrapper/proxy_observer.py b/traincheck/proxy_wrapper/proxy_observer.py index 68e6b834..5316fed6 100644 --- a/traincheck/proxy_wrapper/proxy_observer.py +++ b/traincheck/proxy_wrapper/proxy_observer.py @@ -9,7 +9,7 @@ from traincheck.proxy_wrapper.proxy import Proxy from traincheck.proxy_wrapper.subclass import ProxyParameter -from .proxy_basics import is_proxied, unproxy_func +from .proxy_basics import is_proxied, is_proxyparameter, unproxy_func def observe_proxy_var( @@ -41,9 +41,9 @@ def wrapper(*args, **kwargs): # if the arg is list or tuple, check if it contains proxied object if type(arg) in [list, tuple]: for element in arg: - if is_proxied(element) or isinstance(element, ProxyParameter): + if is_proxied(element) or is_proxyparameter(element): proxied_vars.append(element) - if is_proxied(arg) or isinstance(arg, ProxyParameter): + if 
is_proxied(arg) or is_proxyparameter(arg): proxied_vars.append(arg) # pre observe diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index 5633fd03..90493584 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -9,6 +9,7 @@ from traincheck.utils import get_timestamp_ns from .dumper import json_dumper as dumper +from .proxy_basics import is_fake_tensor # from .proxy_registry import get_global_registry # from .utils import print_debug @@ -23,22 +24,6 @@ def in_dynamo() -> bool: return False -def is_fake_tensor(x: torch.Tensor) -> bool: - try: - from torch._subclasses.fake_tensor import FakeTensor # 2.x - - if isinstance(x, FakeTensor): - return True - except Exception: - pass - if getattr(x, "fake_mode", None) is not None: - return True - if getattr(x, "_is_fake", False): - return True - - return isinstance(x, torch.Tensor) and x.device.type == "meta" - - class ProxyParameter(torch.nn.Parameter): loglevel = logging.INFO jsondumper = dumper( @@ -59,13 +44,13 @@ def __new__( # TODO # from_copy=False, ): + if isinstance(data, ProxyParameter): + return data if in_dynamo() or is_fake_tensor(data): if isinstance(data, nn.Parameter): return data return nn.Parameter(data, requires_grad=data.requires_grad) # TODO: verify - if isinstance(data, ProxyParameter): - return data return torch.Tensor._make_subclass(cls, data.detach(), data.requires_grad) diff --git a/traincheck/utils.py b/traincheck/utils.py index 9e332094..944fd989 100644 --- a/traincheck/utils.py +++ b/traincheck/utils.py @@ -35,6 +35,14 @@ def safe_getattr(obj, attr, default=None): raise +def typename_compile(o): + try: + mod = getattr(type(o), "__module__", "") or "" + return f"{mod}.{type(o).__name__}" + except Exception: + return "compile_stage" + + def typename(o, is_runtime=False): if isinstance(o, torch.nn.Parameter): return "torch.nn.Parameter" From 49dd5f04d18af11534932cedcf1001f783a949cf Mon Sep 17 00:00:00 2001 From: yinjie Date: Sat, 20 Sep 2025 17:58:23 -0400 Subject: [PATCH 12/17] fix: error message includes proxyparameter --- traincheck/instrumentor/source_file.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 62ca16bd..0a386eb8 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -298,7 +298,9 @@ def instrument_all_model_assignments( ) else: - raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']") + raise ValueError( + f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler', 'proxyparameter']" + ) # find all assignment statements to `model` assignments = [] From 8f5c979392dd889b71e5a1fd490eb3c2ba4da85e Mon Sep 17 00:00:00 2001 From: yinjie Date: Sun, 21 Sep 2025 10:11:09 -0400 Subject: [PATCH 13/17] feat: compile mode --- traincheck/collect_trace.py | 7 +++++++ traincheck/config/config.py | 1 + traincheck/instrumentor/source_file.py | 6 ++++++ traincheck/proxy_wrapper/proxy_basics.py | 11 ++++++++--- 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index 69c3d7c0..b128e085 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -371,6 +371,11 @@ def main(): action="store_true", help="Disable automatic variable instrumentation, necessary when the default behavior of the instrumentor is not desired (e.g. 
cause segmentation fault)", ) + parser.add_argument( + "--use-torch-compile", + action="store_true", + help="Indicate whether torch.compile is used to speed up the model; necessary for compatibility", + ) args = parser.parse_args() @@ -481,6 +486,7 @@ def main(): output_dir=output_dir, instr_descriptors=args.instr_descriptors, no_auto_var_instr=args.no_auto_var_instr, + use_torch_compile=args.use_torch_compile, ) else: source_code = instrumentor.instrument_file( @@ -496,6 +502,7 @@ def main(): output_dir=output_dir, instr_descriptors=args.instr_descriptors, no_auto_var_instr=args.no_auto_var_instr, + use_torch_compile=args.use_torch_compile, ) if args.copy_all_files: diff --git a/traincheck/config/config.py b/traincheck/config/config.py index 6b284d99..55a6295d 100644 --- a/traincheck/config/config.py +++ b/traincheck/config/config.py @@ -238,6 +238,7 @@ def should_disable_proxy_dumping() -> bool: INSTR_DESCRIPTORS = False +USE_TORCH_COMPILE = False ALL_STAGE_NAMES = { "init", diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 0a386eb8..9f2bcb43 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -810,6 +810,7 @@ def instrument_file( output_dir: str, instr_descriptors: bool, no_auto_var_instr: bool, + use_torch_compile: bool, ) -> str: """ Instruments the given file and returns the instrumented source code. @@ -847,6 +848,11 @@ def instrument_file( import traincheck.config.config as general_config general_config.INSTR_DESCRIPTORS = {instr_descriptors} """ + if use_torch_compile: + torch_compile_config_update = """ +general_config.USE_TORCH_COMPILE = True +""" + general_config_update = general_config_update + torch_compile_config_update # TODO: move the INSTR_DESCRIPTORS to the instr_opts file if models_to_track: diff --git a/traincheck/proxy_wrapper/proxy_basics.py b/traincheck/proxy_wrapper/proxy_basics.py index 071c6f4e..dd3014bb 100644 --- a/traincheck/proxy_wrapper/proxy_basics.py +++ b/traincheck/proxy_wrapper/proxy_basics.py @@ -4,12 +4,12 @@ import astor -from traincheck.config.config import COMPILE_INTERNAL_MODULE +import traincheck.config.config as config def is_compile_internal_module(obj): mod = getattr(type(obj), "__module__", "") or "" - if any(mod.startswith(p) for p in COMPILE_INTERNAL_MODULE): + if any(mod.startswith(p) for p in config.COMPILE_INTERNAL_MODULE): return True name = type(obj).__name__ @@ -18,6 +18,8 @@ def is_compile_internal_module(obj): def is_fake_tensor(x) -> bool: + if not config.USE_TORCH_COMPILE: + return False try: from torch._subclasses.fake_tensor import FakeTensor from torch.fx import Proxy as FxProxy @@ -36,10 +38,13 @@ def is_fake_tensor(x) -> bool: return True try: - return x.device.type == "meta" + if x.device.type == "meta": + return True except Exception: return True + return False + def is_proxied(obj): try: From 79b4a0749b7014d0685fad6cf179eacabcea0c0 Mon Sep 17 00:00:00 2001 From: Yuxuan Date: Mon, 29 Sep 2025 13:15:14 -0400 Subject: [PATCH 14/17] fix: checker trace path parsing --- traincheck/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/traincheck/checker.py b/traincheck/checker.py index 62045468..dd8816f5 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -153,7 +153,7 @@ def main(): trace_parent_folders = [] if args.traces is not None: logger.info("Reading traces from %s", "\n".join(args.traces)) - trace_parent_folders = 
[os.path.basename(os.path.commonpath(args.traces[0]))] + trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces))] traces.append(read_trace_file(args.traces)) if args.trace_folders is not None: for trace_folder in args.trace_folders: From 84bb44ca37d56bd0374632dff212ad7b40fbc951 Mon Sep 17 00:00:00 2001 From: Yuxuan Date: Mon, 29 Sep 2025 16:14:39 -0400 Subject: [PATCH 15/17] fix: make sure all python-level states are copied over when subclassing tensors --- traincheck/proxy_wrapper/subclass.py | 42 ++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index 90493584..cf0432ba 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -46,13 +46,49 @@ def __new__( ): if isinstance(data, ProxyParameter): return data + if in_dynamo() or is_fake_tensor(data): + # we do not proxy the parameter if we are in dynamo or the tensor is a fake tensor if isinstance(data, nn.Parameter): return data return nn.Parameter(data, requires_grad=data.requires_grad) - # TODO: verify - return torch.Tensor._make_subclass(cls, data.detach(), data.requires_grad) + requires_grad = getattr(data, "requires_grad", False) + tensor_grad = getattr(data, "grad", None) + + # When wrapping an existing Parameter we need to preserve any Python level + # attributes (e.g. hooks, user defined flags, ``grad``) so that the proxy + # behaves identically to the original parameter. ``Parameter.__new__`` + # returns a fresh instance, so we snapshot the metadata from ``data`` and + # replay it on the new ProxyParameter via the base Tensor ``__setattr__`` + # to avoid triggering the logging logic implemented in this class. + snapshot: dict = {} + + if isinstance(data, nn.Parameter): + snapshot = dict(getattr(data, "__dict__", {})) + base_tensor = data.detach() + elif isinstance(data, torch.Tensor): + base_tensor = data.detach() + else: + base_tensor = torch.as_tensor(data) + + proxied = super().__new__(cls, base_tensor, requires_grad=requires_grad) + + if snapshot: + tensor_setattr = torch.Tensor.__setattr__ + for name, value in snapshot.items(): + if name == "grad": + continue + try: + tensor_setattr(proxied, name, value) + except AttributeError: + # Some slots (e.g. torch internals) are read-only; skip them. 
+ continue + + if tensor_grad is not None: + torch.Tensor.__setattr__(proxied, "grad", tensor_grad) + + return proxied def __init__( self, @@ -107,7 +143,7 @@ def __setattr__(self, name, value): self.update_timestamp() self.dump_trace( phase="update", - dump_loc=f"__setattr__ (attribute '{name}')", + dump_loc=f"__setattr__ (attribute '{name}' to {value})", ) def __deepcopy__(self, memo): From d542af247d37d3de828e66d32e7c561ed2508c66 Mon Sep 17 00:00:00 2001 From: Yuxuan Date: Mon, 29 Sep 2025 16:18:22 -0400 Subject: [PATCH 16/17] add: refine setattr log for better debugging --- traincheck/invariant/precondition.py | 1 - traincheck/proxy_wrapper/proxy.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/traincheck/invariant/precondition.py b/traincheck/invariant/precondition.py index ad507dcb..b2040d83 100644 --- a/traincheck/invariant/precondition.py +++ b/traincheck/invariant/precondition.py @@ -537,7 +537,6 @@ def find_precondition_from_single_group( if len(example) == 0: raise ValueError("Empty example found in positive examples") - # HACK: in ConsistencyRelation in order to avoid the field used in the invariant, we need to skip the field in the precondition. It is up to the caller to provide the keys to skip. We should try to refactor this to have a more generic solution. earliest_time = example[0]["time"] process_id = example[0]["process_id"] thread_id = example[0]["thread_id"] diff --git a/traincheck/proxy_wrapper/proxy.py b/traincheck/proxy_wrapper/proxy.py index 68a839d9..3181aa24 100644 --- a/traincheck/proxy_wrapper/proxy.py +++ b/traincheck/proxy_wrapper/proxy.py @@ -369,7 +369,7 @@ def __setattr__(self, name, value): self.dump_trace( phase="update", - dump_loc=f"__setattr__ (attribute '{name}')", + dump_loc=f"__setattr__ (attribute '{name}' to {value}')", ) def __getitem__(self, key): From a8e201566b798528b300df9e41092e1eb42d7cde Mon Sep 17 00:00:00 2001 From: Yuxuan Date: Tue, 30 Sep 2025 13:04:38 -0400 Subject: [PATCH 17/17] add: use consistent trace dumping logic for subclass wrapper --- traincheck/proxy_wrapper/proxy.py | 2 +- traincheck/proxy_wrapper/subclass.py | 33 ++++++++++++++-------------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/traincheck/proxy_wrapper/proxy.py b/traincheck/proxy_wrapper/proxy.py index 3181aa24..68a839d9 100644 --- a/traincheck/proxy_wrapper/proxy.py +++ b/traincheck/proxy_wrapper/proxy.py @@ -369,7 +369,7 @@ def __setattr__(self, name, value): self.dump_trace( phase="update", - dump_loc=f"__setattr__ (attribute '{name}' to {value}')", + dump_loc=f"__setattr__ (attribute '{name}')", ) def __getitem__(self, key): diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py index cf0432ba..f335ba50 100644 --- a/traincheck/proxy_wrapper/subclass.py +++ b/traincheck/proxy_wrapper/subclass.py @@ -5,10 +5,11 @@ import torch from torch import nn +from traincheck.instrumentor.dumper import dump_trace_VAR +from traincheck.instrumentor.tracer import TraceLineType from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars from traincheck.utils import get_timestamp_ns -from .dumper import json_dumper as dumper from .proxy_basics import is_fake_tensor # from .proxy_registry import get_global_registry @@ -26,9 +27,6 @@ def in_dynamo() -> bool: class ProxyParameter(torch.nn.Parameter): loglevel = logging.INFO - jsondumper = dumper( - os.path.join(os.getenv("ML_DAIKON_OUTPUT_DIR", "."), "proxy_log.json") # type: ignore - ) def __new__( cls, @@ -143,7 +141,7 @@ def 
__setattr__(self, name, value): self.update_timestamp() self.dump_trace( phase="update", - dump_loc=f"__setattr__ (attribute '{name}' to {value})", + dump_loc=f"__setattr__ (attribute '{name}')", ) def __deepcopy__(self, memo): @@ -184,18 +182,19 @@ def dump_trace(self, phase, dump_loc): # TODO # if not isinstance(obj, torch.nn.Module): - self.jsondumper.dump_json( - process_id=self.process_id, - thread_id=self.thread_id, - time=last_update_timestamp, - meta_vars=get_meta_vars(self), - var_name=var_name, - # TODO - var_type="torch.nn.Parameter", - change_type=phase, - # TODO: verify dump_attributes - var_attributes=dump_attributes(self, self), - dump_loc=dump_loc, + dump_trace_VAR( + { + "process_id": self.process_id, + "thread_id": self.thread_id, + "time": last_update_timestamp, + "meta_vars": get_meta_vars(self), + "var_name": var_name, + "var_type": "torch.nn.Parameter", + "mode": phase, + "dump_loc": dump_loc, + "attributes": dump_attributes(self, self), + "type": TraceLineType.STATE_CHANGE, + } )
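
Taken together, the series leaves traincheck/proxy_wrapper/subclass.py exposing two entry points: the ProxyParameter tensor subclass and the recursive proxy_parameter() wrapper. Below is a minimal usage sketch, assuming a traincheck install matching the final patch state; the toy model, optimizer, and training step are illustrative assumptions and are not part of the patches, and where the trace records land is determined by the dump_trace_VAR machinery in traincheck.instrumentor.dumper, which the instrumentor normally configures.

import copy

import torch
from torch import nn

from traincheck.proxy_wrapper.subclass import ProxyParameter, proxy_parameter

# Illustrative toy model (an assumption, not part of the patch series).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Replace every parameter in-place with a ProxyParameter; var_names follow the
# module hierarchy, e.g. "model.0.weight". Under torch._dynamo tracing (and,
# with --use-torch-compile, for fake/meta tensors) wrapping is skipped so that
# compilation sees plain nn.Parameters.
proxy_parameter(model, parent_name="model")
assert isinstance(model[0].weight, ProxyParameter)

# An ordinary training step still works: optimizer updates mutate .data in
# place, while Python-level attribute assignments, such as
# zero_grad(set_to_none=True) doing `p.grad = None`, route through the
# instrumented __setattr__, which refreshes last_update_timestamp and emits a
# STATE_CHANGE record via dump_trace.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(16, 4)).sum()
loss.backward()
opt.step()
opt.zero_grad(set_to_none=True)  # dump_trace(phase="update", dump_loc="__setattr__ (attribute 'grad')")

# __deepcopy__ preserves the subclass and its var_name, so checkpoint-style
# copies of the model stay tracked.
clone = copy.deepcopy(model)
assert isinstance(clone[0].weight, ProxyParameter)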