diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index cff346b9f4aa..42e3d66a5e70 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1785,11 +1785,11 @@ def __init__(self, outputs, inputs, params=None):
 
         syms, self._in_format = _flatten(inputs, "input")
         out, self._out_format = _flatten(outputs, "output")
-        input_names = set()
+        input_name_set = set()
         for i in syms:
             assert len(i.get_internals().list_outputs()) == 1, \
                 "Input symbols must be variable, but %s is an output of operators"%str(i)
-            input_names.add(i.name)
+            input_name_set.add(i.name)
 
         # check if any symbol is row_sparse
         row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
@@ -1806,35 +1806,54 @@ def __init__(self, outputs, inputs, params=None):
 
         # Infer type of parameters. Without this, every parameter will be created with
         # default type i.e., fp32
-        arg_params = out.list_arguments()
-        aux_params = out.list_auxiliary_states()
+        arg_param_li = out.list_arguments()
+        aux_param_li = out.list_auxiliary_states()
+        input_di = out.get_inputs()
 
-        arg_types, aux_types = _infer_param_types(syms, out, arg_params, aux_params)
+        arg_type_li, aux_type_li = _infer_param_types(syms, out, arg_param_li, aux_param_li)
 
         if params is None:
             params = {}
-        unused_params = set(params.keys()) - set(arg_params) - set(aux_params)
+        unused_params = set(params.keys()) - set(arg_param_li) - set(aux_param_li)
         if len(unused_params) > 0:
             raise ValueError('{} params are unused by the model.'.format(unused_params))
         self._reg_params = params
 
+        def _extract_initializer(_s_):
+            """Return the initializer from a symbol's '__init__' attr (None if unset)."""
+            _initer_json_ = _s_.list_attr().get('__init__')
+            if _initer_json_ is None:
+                return None
+            try:
+                _type_str_, _args_di_ = json.loads(_initer_json_)
+            except json.JSONDecodeError:  # not JSON: treat the value as a bare type name
+                _type_str_, _args_di_ = _initer_json_, {}
+            return initializer.create(_type_str_, **_args_di_)
-        for i, arg in enumerate(arg_params):
+        for i, arg in enumerate(arg_param_li):
             if arg in self._reg_params:
-                self._reg_params[arg]._check_and_setattr(allow_deferred_init=True, dtype=arg_types[i])
+                self._reg_params[arg]._check_and_setattr(allow_deferred_init=True, dtype=arg_type_li[i])
                 if self._reg_params[arg]._var is None:
                     self._reg_params[arg]._var_name = arg
-            elif arg not in input_names:
-                self._reg_params[arg] = Parameter(name=arg, allow_deferred_init=True, dtype=arg_types[i])
+            elif arg not in input_name_set:
+                sym_ = input_di[arg]
+                sym_attr = sym_.list_attr()
+                self._reg_params[arg] = Parameter(
+                    name=arg,
+                    init=_extract_initializer(sym_),
+                    lr_mult=float(sym_attr.get('__lr_mult__', 1.0)),
+                    wd_mult=float(sym_attr.get('__wd_mult__', 1.0)),
+                    allow_deferred_init=True,
+                    dtype=arg_type_li[i])
                 self._reg_params[arg]._var_name = arg
-        for i, aux in enumerate(aux_params):
+        for i, aux in enumerate(aux_param_li):
             if aux in self._reg_params:
                 self._reg_params[aux]._check_and_setattr(grad_req='null', allow_deferred_init=True,
-                                                         dtype=aux_types[i])
+                                                         dtype=aux_type_li[i])
                 if self._reg_params[aux]._var is None:
                     self._reg_params[aux]._var_name = aux
-            elif aux not in input_names:
+            elif aux not in input_name_set:
                 self._reg_params[aux] = Parameter(name=aux, grad_req='null',
-                                                  allow_deferred_init=True, dtype=aux_types[i])
+                                                  allow_deferred_init=True, dtype=aux_type_li[i])
                 self._reg_params[aux]._var_name = aux
 
         self._cached_graph = syms, out
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 33fd48a256a6..952e9cf95c30 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -224,6 +224,27 @@ def test_basic():
     model.setattr('grad_req', 'write')
     assert list(model.collect_params().values())[0]._grad is not None
 
+@use_np
+def test_symbol_block_init():
+    """SymbolBlock parameters pick up lr_mult, wd_mult and init from symbol variables."""
+    DTYPE = mx.np.float32
+    LR_MULT, WD_MULT = 0.555, 0.444
+    svar = mx.symbol.var
+    s_x = svar('x', shape=(1,256,), dtype=DTYPE)
+    s_w = svar('W', shape=(256,192), dtype=DTYPE, lr_mult=LR_MULT, wd_mult=WD_MULT)
+    s_b = svar('b', shape=(1,192,), dtype=DTYPE, init=mx.init.Zero())
+    s_y = mx.symbol.linalg.gemm(s_x, s_w, s_b)
+
+    fn = mx.gluon.SymbolBlock([s_y], [s_x])
+    fn.initialize()
+    v_x = mx.nd.random_uniform(-1., 1., shape=(1,256), dtype=DTYPE,
+                               ctx=mx.device.current_device())
+    fn.forward(v_x)
+    param_di = fn.collect_params()
+    v_w, v_b = param_di['W'], param_di['b']
+    assert v_w.lr_mult == LR_MULT
+    assert v_w.wd_mult == WD_MULT
+    assert not v_b.data().asnumpy().any()
 
 def test_sparse_symbol_block():
     data = mx.sym.var('data')