Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 148 additions & 189 deletions docs/precision_checker_guide.md

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
#include "infini_train/include/profiler.h"
#endif
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/utils/global_module_hook_registry.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
Expand Down Expand Up @@ -257,6 +259,9 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

const bool last_step = step == FLAGS_num_iteration;

const auto iter_start = std::chrono::high_resolution_clock::now();
Expand Down
5 changes: 5 additions & 0 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/utils/global_module_hook_registry.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
Expand Down Expand Up @@ -232,6 +234,9 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

const bool last_step = step == FLAGS_num_iteration;

const auto iter_start = std::chrono::high_resolution_clock::now();
Expand Down
1 change: 0 additions & 1 deletion infini_train/include/nn/modules/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ class Module : public std::enable_shared_from_this<Module> {
std::vector<ModulePostHook> forward_post_hooks_;
std::vector<ModulePreHook> backward_pre_hooks_;
std::vector<ModulePostHook> backward_post_hooks_;
bool precision_check_registered_ = false;

private:
std::unordered_map<std::string, std::shared_ptr<Module>>
Expand Down
41 changes: 41 additions & 0 deletions infini_train/include/utils/global_module_hook_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once

#include "infini_train/include/common/hook.h"
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_set>
#include <vector>

namespace infini_train {
namespace nn {
class Module;
}

namespace utils {

// Global Module Hook Registry
// Manages hooks that need to be applied to all modules
class GlobalModuleHookRegistry {
public:
using ModuleHookRegistrar = std::function<void(nn::Module *)>;

static GlobalModuleHookRegistry &Instance();

// Register a hook registrar, which will be called for all modules on their first forward pass
// Returns a HookHandle that can be used to remove the hook
std::unique_ptr<HookHandle> RegisterHook(ModuleHookRegistrar registrar);

// Apply all registered hooks to the specified module (called by Module::operator())
void ApplyHooks(nn::Module *module);

private:
GlobalModuleHookRegistry() = default;

std::vector<ModuleHookRegistrar> registrars_;
std::unordered_set<nn::Module *> applied_modules_;
mutable std::mutex mutex_;
};

} // namespace utils
} // namespace infini_train
16 changes: 12 additions & 4 deletions infini_train/include/utils/precision_check_config.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <string>
#include <unordered_map>

namespace infini_train {
namespace utils {
Expand All @@ -9,10 +10,11 @@ enum class PrecisionCheckLevel { OFF = 0, MODULE = 1, FUNCTION = 2 };

struct PrecisionCheckConfig {
PrecisionCheckLevel level = PrecisionCheckLevel::OFF;
std::string output_path = ""; // empty=console(rank0), non-empty=file(all ranks)
bool output_md5 = false; // output MD5 hash or tensor values
std::string format = "simple"; // "simple" or "table"
std::string baseline_path = ""; // baseline file path for comparison
std::string output_path = "./log_precision_check"; // Output path (default)
std::string format = "simple"; // "simple" or "md5"
bool save_tensors = false; // Whether to output .npy file
double md5_tolerance = 0.0; // MD5 tolerance for quantization (e.g., 1e-3)
// 0 means no quantization (original precision)

// Parse from "key=value,key=value" string
static PrecisionCheckConfig Parse(const std::string &config_str);
Expand All @@ -23,10 +25,16 @@ class PrecisionCheckEnv {
static PrecisionCheckEnv &Instance();
void Init(const PrecisionCheckConfig &config);
const PrecisionCheckConfig &GetConfig() const;
const std::string &GetOutputPath() const;

// Tensor counter management for file overwrite across iterations (thread-local)
static int GetAndIncrementCounter(const std::string &key);
static void ResetCounters();

private:
PrecisionCheckEnv() = default;
PrecisionCheckConfig config_;
std::string timestamped_path_ = ""; // Actual output path (with timestamp)
};

} // namespace utils
Expand Down
9 changes: 9 additions & 0 deletions infini_train/include/utils/precision_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <string>
#include <vector>

#include "infini_train/include/utils/precision_check_config.h"

namespace infini_train {
class Tensor;
class HookHandle;
Expand Down Expand Up @@ -32,13 +34,20 @@ class PrecisionChecker {
return default_config;
}

// Initialize global module-level precision checking
// Called automatically by PrecisionCheckEnv::Init when level >= MODULE
static void Init(const PrecisionCheckConfig &global_config, const Config &config = DefaultConfig());

static void RegisterForFunction(autograd::Function *func, const std::string &name = "",
const Config &config = DefaultConfig());

// Register hooks for a Module (checks forward inputs/outputs)
static void RegisterForModule(nn::Module *module, const std::string &name = "",
const Config &config = DefaultConfig());

// Reset tensor counters (call at start of each iteration for file overwrite)
static void ResetCounters();

private:
static void CheckTensors(const std::string &stage, const std::string &name,
const std::vector<std::shared_ptr<Tensor>> &tensors, const Config &config);
Expand Down
14 changes: 3 additions & 11 deletions infini_train/src/nn/modules/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
#include "infini_train/include/device.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/tensor.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"
#include "infini_train/include/utils/global_module_hook_registry.h"

#ifndef UNLIKELY
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
Expand Down Expand Up @@ -135,15 +134,8 @@ std::vector<std::shared_ptr<Tensor>> Module::Forward(const std::vector<std::shar
}

std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
// Register precision check hooks if enabled and not already registered
// TODO(cx): move RegisterForModule to PrecisionChecker and avoid duplicate registration
if (!precision_check_registered_) {
auto precision_level = utils::PrecisionCheckEnv::Instance().GetConfig().level;
if (precision_level == utils::PrecisionCheckLevel::MODULE) {
utils::PrecisionChecker::RegisterForModule(this);
precision_check_registered_ = true;
}
}
// Apply globally registered hooks (on first call for this module)
utils::GlobalModuleHookRegistry::Instance().ApplyHooks(this);

// Call forward pre-hooks
for (const auto &hook : forward_pre_hooks_) {
Expand Down
29 changes: 29 additions & 0 deletions infini_train/src/utils/global_module_hook_registry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "infini_train/include/utils/global_module_hook_registry.h"

namespace infini_train::utils {

GlobalModuleHookRegistry &GlobalModuleHookRegistry::Instance() {
static GlobalModuleHookRegistry instance;
return instance;
}

std::unique_ptr<HookHandle> GlobalModuleHookRegistry::RegisterHook(ModuleHookRegistrar registrar) {
std::lock_guard<std::mutex> lock(mutex_);
registrars_.push_back(std::move(registrar));
return std::make_unique<HookHandleImpl<ModuleHookRegistrar>>(&registrars_, registrars_.size() - 1);
}

void GlobalModuleHookRegistry::ApplyHooks(nn::Module *module) {
std::lock_guard<std::mutex> lock(mutex_);
if (applied_modules_.contains(module)) {
return;
}
for (const auto &registrar : registrars_) {
if (registrar) {
registrar(module);
}
}
applied_modules_.insert(module);
}

} // namespace infini_train::utils
57 changes: 45 additions & 12 deletions infini_train/src/utils/precision_check_config.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
#include "infini_train/include/utils/precision_check_config.h"

#include <chrono>
#include <filesystem>
#include <iostream>
#include <sstream>
#include <unordered_map>

#include "infini_train/include/utils/precision_checker.h"

namespace infini_train::utils {

namespace {
// Thread-local tensor counter for precision check file indexing
thread_local std::unordered_map<std::string, int> tls_g_tensor_counter;
} // namespace

PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) {
PrecisionCheckConfig config;
if (config_str.empty()) {
Expand All @@ -25,20 +35,17 @@ PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str)
int level_int = std::stoi(kv_map["level"]);
config.level = static_cast<PrecisionCheckLevel>(level_int);
}
if (kv_map.count("output_path")) {
config.output_path = kv_map["output_path"];
}
if (kv_map.count("output_md5")) {
config.output_md5 = (kv_map["output_md5"] == "true" || kv_map["output_md5"] == "1");
}
if (kv_map.count("baseline")) {
config.baseline_path = kv_map["baseline"];
if (kv_map.count("path")) {
config.output_path = kv_map["path"];
}
if (kv_map.count("format")) {
config.format = kv_map["format"];
} else if (!config.baseline_path.empty()) {
// Default to table format when baseline is specified
config.format = "table";
}
if (kv_map.count("save_tensors")) {
config.save_tensors = (kv_map["save_tensors"] == "true" || kv_map["save_tensors"] == "1");
}
if (kv_map.count("md5_tolerance")) {
config.md5_tolerance = std::stod(kv_map["md5_tolerance"]);
}
return config;
}
Expand All @@ -48,8 +55,34 @@ PrecisionCheckEnv &PrecisionCheckEnv::Instance() {
return instance;
}

void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) { config_ = config; }
void PrecisionCheckEnv::Init(const PrecisionCheckConfig &config) {
config_ = config;
if (config_.level != PrecisionCheckLevel::OFF) {
// Create timestamped subdirectory: output_path/YYYYMMDD_HHMMSS/
auto now = std::chrono::system_clock::now();
auto time_t = std::chrono::system_clock::to_time_t(now);
std::tm tm;
localtime_r(&time_t, &tm);
char buf[32];
std::strftime(buf, sizeof(buf), "%Y%m%d_%H%M%S", &tm);

timestamped_path_ = config_.output_path + "/" + buf;
std::filesystem::create_directories(timestamped_path_);

// Initialize PrecisionChecker (registers global module hooks)
PrecisionChecker::Init(config_);

// Output precision check output path
std::cout << "[PrecisionCheck] Output: " << timestamped_path_ << std::endl;
}
}

const PrecisionCheckConfig &PrecisionCheckEnv::GetConfig() const { return config_; }

const std::string &PrecisionCheckEnv::GetOutputPath() const { return timestamped_path_; }

int PrecisionCheckEnv::GetAndIncrementCounter(const std::string &key) { return tls_g_tensor_counter[key]++; }

void PrecisionCheckEnv::ResetCounters() { tls_g_tensor_counter.clear(); }

} // namespace infini_train::utils
4 changes: 2 additions & 2 deletions infini_train/src/utils/precision_check_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
namespace infini_train::utils {

PrecisionCheckContext &PrecisionCheckContext::Instance() {
static thread_local PrecisionCheckContext instance;
return instance;
static thread_local PrecisionCheckContext tls_instance;
return tls_instance;
}

void PrecisionCheckContext::SetGAS(int gas) { gas_ = gas; }
Expand Down
Loading