Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hessian/model_debugger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion hessian/model_debugger_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion hessian/precondition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion hessian/test_model_debugger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion hessian/test_precondition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/base_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/autoaugment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/criteo_terabyte_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/data_selectors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
209 changes: 132 additions & 77 deletions init2winit/dataset_lib/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,6 +31,7 @@
from init2winit.dataset_lib import nanodo_fineweb_edu
from init2winit.dataset_lib import nqm_noise
from init2winit.dataset_lib import ogbg_molpcba
from init2winit.dataset_lib import ogbg_molpcba_preprocessed
from init2winit.dataset_lib import proteins
from init2winit.dataset_lib import small_image_datasets
from init2winit.dataset_lib import translate_wmt
Expand All @@ -42,90 +43,145 @@
'Dataset', ('getter', 'hparams', 'meta_data', 'fake_batch_getter'))

_ALL_DATASETS = {
'mnist':
_Dataset(small_image_datasets.get_mnist,
small_image_datasets.MNIST_HPARAMS,
small_image_datasets.MNIST_METADATA, None),
'mnist_autoencoder':
_Dataset(small_image_datasets.get_mnist_autoencoder,
small_image_datasets.MNIST_AUTOENCODER_HPARAMS,
small_image_datasets.MNIST_AUTOENCODER_METADATA, None),
'fashion_mnist':
_Dataset(small_image_datasets.get_fashion_mnist,
small_image_datasets.FASHION_MNIST_HPARAMS,
small_image_datasets.FASHION_MNIST_METADATA, None),
'cifar10':
_Dataset(small_image_datasets.get_cifar10,
small_image_datasets.CIFAR10_DEFAULT_HPARAMS,
small_image_datasets.CIFAR10_METADATA, None),
'cifar100':
_Dataset(small_image_datasets.get_cifar100,
small_image_datasets.CIFAR100_DEFAULT_HPARAMS,
small_image_datasets.CIFAR100_METADATA, None),
'criteo1tb':
_Dataset(criteo_terabyte_dataset.get_criteo1tb,
criteo_terabyte_dataset.CRITEO1TB_DEFAULT_HPARAMS,
criteo_terabyte_dataset.CRITEO1TB_METADATA,
criteo_terabyte_dataset.get_fake_batch),
'fake':
_Dataset(fake_dataset.get_fake, fake_dataset.DEFAULT_HPARAMS,
fake_dataset.METADATA, fake_dataset.get_fake_batch),
'fastmri':
_Dataset(fastmri_dataset.get_fastmri, fastmri_dataset.DEFAULT_HPARAMS,
fastmri_dataset.METADATA, fastmri_dataset.get_fake_batch),
'mnist': _Dataset(
small_image_datasets.get_mnist,
small_image_datasets.MNIST_HPARAMS,
small_image_datasets.MNIST_METADATA,
None,
),
'mnist_autoencoder': _Dataset(
small_image_datasets.get_mnist_autoencoder,
small_image_datasets.MNIST_AUTOENCODER_HPARAMS,
small_image_datasets.MNIST_AUTOENCODER_METADATA,
None,
),
'fashion_mnist': _Dataset(
small_image_datasets.get_fashion_mnist,
small_image_datasets.FASHION_MNIST_HPARAMS,
small_image_datasets.FASHION_MNIST_METADATA,
None,
),
'cifar10': _Dataset(
small_image_datasets.get_cifar10,
small_image_datasets.CIFAR10_DEFAULT_HPARAMS,
small_image_datasets.CIFAR10_METADATA,
None,
),
'cifar100': _Dataset(
small_image_datasets.get_cifar100,
small_image_datasets.CIFAR100_DEFAULT_HPARAMS,
small_image_datasets.CIFAR100_METADATA,
None,
),
'criteo1tb': _Dataset(
criteo_terabyte_dataset.get_criteo1tb,
criteo_terabyte_dataset.CRITEO1TB_DEFAULT_HPARAMS,
criteo_terabyte_dataset.CRITEO1TB_METADATA,
criteo_terabyte_dataset.get_fake_batch,
),
'fake': _Dataset(
fake_dataset.get_fake,
fake_dataset.DEFAULT_HPARAMS,
fake_dataset.METADATA,
fake_dataset.get_fake_batch,
),
'fastmri': _Dataset(
fastmri_dataset.get_fastmri,
fastmri_dataset.DEFAULT_HPARAMS,
fastmri_dataset.METADATA,
fastmri_dataset.get_fake_batch,
),
'fineweb_edu_10B': _Dataset(
fineweb_edu_10b.get_fineweb_edu,
fineweb_edu_10b.DEFAULT_HPARAMS,
fineweb_edu_10b.METADATA, None),
'imagenet':
_Dataset(imagenet_dataset.get_imagenet,
imagenet_dataset.DEFAULT_HPARAMS, imagenet_dataset.METADATA,
imagenet_dataset.get_fake_batch),
'translate_wmt':
_Dataset(translate_wmt.get_translate_wmt, translate_wmt.DEFAULT_HPARAMS,
translate_wmt.METADATA, translate_wmt.get_fake_batch),
'librispeech':
_Dataset(librispeech.get_librispeech, librispeech.DEFAULT_HPARAMS,
librispeech.METADATA, librispeech.get_fake_batch),
'lm1b_v2':
_Dataset(lm1b_v2.get_lm1b, lm1b_v2.DEFAULT_HPARAMS, lm1b_v2.METADATA,
None),
'mlperf_imagenet':
_Dataset(mlperf_imagenet_dataset.get_mlperf_imagenet,
mlperf_imagenet_dataset.DEFAULT_HPARAMS,
mlperf_imagenet_dataset.METADATA,
mlperf_imagenet_dataset.get_fake_batch),
'svhn_no_extra':
_Dataset(small_image_datasets.get_svhn_no_extra,
small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS,
small_image_datasets.SVHN_NO_EXTRA_METADATA, None),
fineweb_edu_10b.METADATA,
None,
),
'imagenet': _Dataset(
imagenet_dataset.get_imagenet,
imagenet_dataset.DEFAULT_HPARAMS,
imagenet_dataset.METADATA,
imagenet_dataset.get_fake_batch,
),
'translate_wmt': _Dataset(
translate_wmt.get_translate_wmt,
translate_wmt.DEFAULT_HPARAMS,
translate_wmt.METADATA,
translate_wmt.get_fake_batch,
),
'librispeech': _Dataset(
librispeech.get_librispeech,
librispeech.DEFAULT_HPARAMS,
librispeech.METADATA,
librispeech.get_fake_batch,
),
'lm1b_v2': _Dataset(
lm1b_v2.get_lm1b, lm1b_v2.DEFAULT_HPARAMS, lm1b_v2.METADATA, None
),
'mlperf_imagenet': _Dataset(
mlperf_imagenet_dataset.get_mlperf_imagenet,
mlperf_imagenet_dataset.DEFAULT_HPARAMS,
mlperf_imagenet_dataset.METADATA,
mlperf_imagenet_dataset.get_fake_batch,
),
'svhn_no_extra': _Dataset(
small_image_datasets.get_svhn_no_extra,
small_image_datasets.SVHN_NO_EXTRA_DEFAULT_HPARAMS,
small_image_datasets.SVHN_NO_EXTRA_METADATA,
None,
),
'c4': _Dataset(
nanodo_c4.get_dataset,
nanodo_c4.DEFAULT_HPARAMS,
nanodo_c4.METADATA, None),
nanodo_c4.METADATA,
None,
),
'fineweb_edu': _Dataset(
nanodo_fineweb_edu.get_dataset,
nanodo_fineweb_edu.DEFAULT_HPARAMS,
nanodo_fineweb_edu.METADATA, None),
'nqm_noise':
_Dataset(nqm_noise.get_nqm_noise, nqm_noise.NQM_HPARAMS,
nqm_noise.NQM_METADATA, None),
'ogbg_molpcba':
_Dataset(ogbg_molpcba.get_ogbg_molpcba, ogbg_molpcba.DEFAULT_HPARAMS,
ogbg_molpcba.METADATA, ogbg_molpcba.get_fake_batch),
'uniref50':
_Dataset(proteins.get_uniref, proteins.DEFAULT_HPARAMS,
proteins.METADATA, None),
'wikitext2':
_Dataset(wikitext2.get_wikitext2, wikitext2.DEFAULT_HPARAMS,
wikitext2.METADATA, None),
'wikitext103':
_Dataset(wikitext103.get_wikitext103, wikitext103.DEFAULT_HPARAMS,
wikitext2.METADATA, None),
'wikitext103_spm':
_Dataset(wikitext103_spm.get_wikitext103,
wikitext103_spm.DEFAULT_HPARAMS,
wikitext103_spm.METADATA, None),
nanodo_fineweb_edu.METADATA,
None,
),
'nqm_noise': _Dataset(
nqm_noise.get_nqm_noise,
nqm_noise.NQM_HPARAMS,
nqm_noise.NQM_METADATA,
None,
),
'ogbg_molpcba': _Dataset(
ogbg_molpcba.get_ogbg_molpcba,
ogbg_molpcba.DEFAULT_HPARAMS,
ogbg_molpcba.METADATA,
ogbg_molpcba.get_fake_batch,
),
'ogbg_molpcba_preprocessed': _Dataset(
ogbg_molpcba_preprocessed.get_ogbg_molpcba_preprocessed,
ogbg_molpcba_preprocessed.DEFAULT_HPARAMS,
ogbg_molpcba_preprocessed.METADATA,
# Reuse logic for fake batch if needed, or None
ogbg_molpcba.get_fake_batch,
),
'uniref50': _Dataset(
proteins.get_uniref, proteins.DEFAULT_HPARAMS, proteins.METADATA, None
),
'wikitext2': _Dataset(
wikitext2.get_wikitext2,
wikitext2.DEFAULT_HPARAMS,
wikitext2.METADATA,
None,
),
'wikitext103': _Dataset(
wikitext103.get_wikitext103,
wikitext103.DEFAULT_HPARAMS,
wikitext2.METADATA,
None,
),
'wikitext103_spm': _Dataset(
wikitext103_spm.get_wikitext103,
wikitext103_spm.DEFAULT_HPARAMS,
wikitext103_spm.METADATA,
None,
),
}


Expand Down Expand Up @@ -238,4 +294,3 @@ def get_data_selector(selector_name: Optional[str]):
'Unrecognized selector: {}'.format(selector_name)) from None

return selector

2 changes: 1 addition & 1 deletion init2winit/dataset_lib/fake_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/fastmri_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/fineweb_edu_10b.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/image_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/imagenet_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/imagenet_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion init2winit/dataset_lib/librispeech.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2025 The init2winit Authors.
# Copyright 2026 The init2winit Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading