Skip to content

Commit 2c4213b

Browse files
Merge pull request #74 from oslabs-beta/yiqun/new
feat: add NN multi run example
2 parents d82bfbb + 11b01f2 commit 2c4213b

File tree

1 file changed

+346
-0
lines changed

1 file changed

+346
-0
lines changed
Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
import * as tf from '@tensorflow/tfjs-node';
2+
import MLflow from 'mlflow-js';
3+
import { fileURLToPath } from 'url';
4+
import { dirname } from 'path';
5+
6+
// Shared MLflow client used by every function in this file; it is constructed
// with the same server URL that main() prints as the place to view results.
const mlflow = new MLflow('http://localhost:5001');
7+
8+
/**
 * Grid of hyperparameter values explored by the sweep.
 *
 * One training run is executed for every combination of one entry from each
 * array, so the total number of runs is the product of the four array lengths.
 */
const HYPERPARAMETER_SPACE = {
  // Each entry lists the unit count of the dense layers, in order.
  networkArchitectures: [
    [16, 8], // Small network
    [32, 16], // Medium network
    [64, 32], // Larger network
  ],
  learningRates: [0.001, 0.01],
  batchSizes: [32, 64],
  // A rate of 0 means no dropout layer is added to the model.
  dropoutRates: [0, 0.2],
};
18+
19+
/**
 * Settings shared by every training run in the sweep.
 */
const TRAINING_CONFIG = {
  // Upper bound on epochs per run; early stopping may end a run sooner.
  epochs: 20,
  // Fraction of the samples meant to be held out from training.
  validationSplit: 0.2,
  // Passed as `patience` to the early-stopping callback.
  earlyStoppingPatience: 3,
  // Number of synthetic samples generated.
  datasetSize: 2000,
  // Width of each input sample.
  inputFeatures: 5,
  // Number of target classes (width of the softmax output).
  outputClasses: 3,
  // NOTE(review): not read anywhere in the visible file — confirm whether it
  // is still needed.
  minibatchSize: 128, // Added for faster training
};
28+
29+
// Data generation
30+
/**
 * Builds a synthetic classification dataset and splits it into a training
 * part and a held-out part.
 *
 * Features are sampled with tf.randomNormal; targets are the softmax of a
 * random linear projection of the features, so every target row is a
 * probability distribution over the classes rather than a one-hot label.
 *
 * The held-out fraction comes from TRAINING_CONFIG.validationSplit (0.2 when
 * the key is absent), so the first `floor(datasetSize * (1 - split))` rows are
 * used for training and the remaining rows are held out.
 *
 * @returns {{trainX: tf.Tensor, trainY: tf.Tensor, testX: tf.Tensor, testY: tf.Tensor}}
 *   The four slices. They are returned out of tf.tidy, so the caller owns
 *   them and is responsible for disposing them.
 */
function generateData() {
  const {
    datasetSize,
    inputFeatures,
    outputClasses,
    validationSplit = 0.2,
  } = TRAINING_CONFIG;

  // tf.tidy frees the intermediates (x, weights, logits, y) and keeps only
  // the tensors reachable from the returned object.
  return tf.tidy(() => {
    const x = tf.randomNormal([datasetSize, inputFeatures]);
    const weights = tf.randomNormal([inputFeatures, outputClasses]);
    const y = tf.softmax(x.matMul(weights));

    // Split into train and validation sets
    const splitIdx = Math.floor(datasetSize * (1 - validationSplit));

    return {
      trainX: x.slice([0, 0], [splitIdx, -1]),
      trainY: y.slice([0, 0], [splitIdx, -1]),
      testX: x.slice([splitIdx, 0], [-1, -1]),
      testY: y.slice([splitIdx, 0], [-1, -1]),
    };
  });
}
55+
56+
// Model creation
57+
/**
 * Builds and compiles a dense feed-forward classifier.
 *
 * Layer layout: dense(architecture[0], relu) -> optional dropout -> one
 * dense(relu) layer per remaining entry of `architecture` -> dense softmax
 * output with TRAINING_CONFIG.outputClasses units. Dropout is inserted only
 * once, directly after the first dense layer.
 *
 * @param {number[]} architecture - Unit counts of the dense layers, in order.
 *   Must contain at least one entry.
 * @param {number} learningRate - Learning rate handed to the Adam optimizer.
 * @param {number} dropoutRate - Dropout rate; no dropout layer is added
 *   unless this is greater than 0.
 * @returns {tf.Sequential} The compiled model (categorical cross-entropy
 *   loss, accuracy metric).
 * @throws {Error} If `architecture` is not a non-empty array.
 */
function createModel(architecture, learningRate, dropoutRate) {
  // Fail early with a readable message instead of passing `units: undefined`
  // to the first dense layer.
  if (!Array.isArray(architecture) || architecture.length === 0) {
    throw new Error('architecture must be a non-empty array of layer sizes');
  }

  const [firstUnits, ...remainingUnits] = architecture;
  const model = tf.sequential();

  // Input layer
  model.add(
    tf.layers.dense({
      units: firstUnits,
      inputShape: [TRAINING_CONFIG.inputFeatures],
      activation: 'relu',
    })
  );

  if (dropoutRate > 0) {
    model.add(tf.layers.dropout({ rate: dropoutRate }));
  }

  // Hidden layers
  for (const units of remainingUnits) {
    model.add(tf.layers.dense({ units, activation: 'relu' }));
  }

  // Output layer
  model.add(
    tf.layers.dense({
      units: TRAINING_CONFIG.outputClasses,
      activation: 'softmax',
    })
  );

  model.compile({
    optimizer: tf.train.adam(learningRate),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return model;
}
99+
100+
/**
 * Training callback that forwards per-epoch metrics to MLflow.
 *
 * Metrics are sent on every `logInterval`-th epoch and on the last configured
 * epoch (TRAINING_CONFIG.epochs - 1). When training stops early, the epoch it
 * stopped on is only logged if it happens to fall on the interval.
 */
class MLflowCallback extends tf.Callback {
  /**
   * @param {string} runId - MLflow run that receives the metrics.
   */
  constructor(runId) {
    super();
    this.runId = runId;
    // Number of metric batches successfully sent so far.
    this.batchesLogged = 0;
    this.logInterval = 2; // Log every 2 epochs to reduce overhead
  }

  /**
   * Sends train/validation loss and accuracy for `epoch` to MLflow.
   *
   * @param {number} epoch - Zero-based epoch index, used as the metric step.
   * @param {object} logs - Epoch logs; reads `loss`, `acc`, `val_loss` and
   *   `val_acc`.
   *   NOTE(review): assumes the values have already been resolved to plain
   *   numbers by the time this callback runs — confirm against the callback
   *   order used with model.fit.
   * @returns {Promise<void>}
   */
  async onEpochEnd(epoch, logs) {
    const isOnInterval = epoch % this.logInterval === 0;
    const isLastConfiguredEpoch = epoch === TRAINING_CONFIG.epochs - 1;
    if (!isOnInterval && !isLastConfiguredEpoch) {
      return;
    }

    const epochLogs = logs ?? {};
    // One timestamp for the whole batch keeps the four metrics aligned.
    const timestamp = Date.now();
    const candidates = [
      ['train_loss', epochLogs.loss],
      ['train_accuracy', epochLogs.acc],
      ['val_loss', epochLogs.val_loss],
      ['val_accuracy', epochLogs.val_acc],
    ];

    // Skip values that are missing or not finite (e.g. no validation data, or
    // a NaN loss); they cannot be represented as JSON numbers and would
    // otherwise be sent as part of the batch.
    const metrics = candidates
      .filter(([, value]) => Number.isFinite(value))
      .map(([key, value]) => ({ key, value, timestamp, step: epoch }));

    if (metrics.length === 0) {
      return;
    }

    await mlflow.logBatch(this.runId, metrics);
    this.batchesLogged += 1;
  }
}
143+
144+
/**
 * Fits `model` on the training tensors, with early stopping on validation
 * loss and per-epoch metric logging to MLflow.
 *
 * @param {tf.LayersModel} model - Compiled model to train.
 * @param {tf.Tensor} trainX - Training inputs.
 * @param {tf.Tensor} trainY - Training targets.
 * @param {tf.Tensor} valX - Validation inputs.
 * @param {tf.Tensor} valY - Validation targets.
 * @param {string} runId - MLflow run that receives the epoch metrics.
 * @param {number} batchSize - Batch size passed to model.fit.
 * @returns {Promise<*>} Whatever model.fit resolves to.
 */
async function trainModel(model, trainX, trainY, valX, valY, runId, batchSize) {
  const earlyStopping = tf.callbacks.earlyStopping({
    monitor: 'val_loss',
    patience: TRAINING_CONFIG.earlyStoppingPatience,
  });

  return model.fit(trainX, trainY, {
    epochs: TRAINING_CONFIG.epochs,
    batchSize,
    validationData: [valX, valY],
    // Early stopping is listed first, so it runs before the MLflow logger.
    callbacks: [earlyStopping, new MLflowCallback(runId)],
    shuffle: true,
  });
}
159+
160+
/**
 * Evaluates `model` on the given data and returns plain JavaScript values.
 *
 * @param {tf.LayersModel} model - Trained model compiled with one metric, so
 *   model.evaluate yields [loss, metric].
 * @param {tf.Tensor} testX - Inputs to evaluate on.
 * @param {tf.Tensor} testY - Targets; the class is taken as the arg-max along
 *   axis 1.
 * @returns {{testLoss: number, testAccuracy: number, confusionMatrix: number[][]}}
 *   Loss, accuracy, and a confusion matrix with
 *   TRAINING_CONFIG.outputClasses rows/columns (rows: true class, columns:
 *   predicted class).
 */
function evaluateModel(model, testX, testY) {
  // Everything is read back synchronously inside tf.tidy, so only plain
  // numbers/arrays leave the scope and all intermediate tensors are freed.
  return tf.tidy(() => {
    const [lossScalar, accuracyScalar] = model.evaluate(testX, testY);
    const predictions = model.predict(testX);

    const trueClasses = tf.argMax(testY, 1);
    const predictedClasses = tf.argMax(predictions, 1);
    const confusionMatrix = tf.math.confusionMatrix(
      trueClasses,
      predictedClasses,
      TRAINING_CONFIG.outputClasses
    );

    return {
      testLoss: lossScalar.dataSync()[0],
      testAccuracy: accuracyScalar.dataSync()[0],
      confusionMatrix: confusionMatrix.arraySync(),
    };
  });
}
178+
179+
/**
 * Executes one MLflow run: logs the hyperparameters, trains and evaluates a
 * model, logs the final metrics and confusion matrix, saves the model under
 * the run's artifact directory and marks the run FINISHED.
 *
 * On any failure the run is marked FAILED (best effort) and the original
 * error is rethrown. The model is disposed in every case so repeated calls do
 * not accumulate model weights in memory.
 *
 * @param {string} experimentId - MLflow experiment that owns the run.
 * @param {{architecture: number[], learningRate: number, batchSize: number, dropoutRate: number}} hyperparams
 *   Hyperparameter combination for this run.
 * @param {{trainX: tf.Tensor, trainY: tf.Tensor, testX: tf.Tensor, testY: tf.Tensor}} data
 *   Dataset; the test tensors are used both as validation data during
 *   training and for the final evaluation.
 * @returns {Promise<{runId: string, metrics: {testLoss: number, testAccuracy: number, confusionMatrix: number[][]}}>}
 * @throws Rethrows whatever error made the run fail.
 */
async function runExperiment(experimentId, hyperparams, data) {
  const { architecture, learningRate, batchSize, dropoutRate } = hyperparams;

  const runName = `NN-${architecture.join('-')}-lr${learningRate}`;
  const run = await mlflow.createRun(experimentId, runName);
  const runId = run.info.run_id;

  // Declared outside the try block so the finally clause can release it.
  let model;

  try {
    // Log hyperparameters
    const params = [
      { key: 'architecture', value: architecture.join(',') },
      { key: 'learning_rate', value: learningRate.toString() },
      { key: 'batch_size', value: batchSize.toString() },
      { key: 'dropout_rate', value: dropoutRate.toString() },
    ];
    await mlflow.logBatch(runId, undefined, params);

    model = createModel(architecture, learningRate, dropoutRate);

    await trainModel(
      model,
      data.trainX,
      data.trainY,
      data.testX,
      data.testY,
      runId,
      batchSize
    );

    const evaluation = evaluateModel(model, data.testX, data.testY);

    const timestamp = Date.now();
    const finalMetrics = [
      { key: 'test_loss', value: evaluation.testLoss, timestamp },
      { key: 'test_accuracy', value: evaluation.testAccuracy, timestamp },
    ];
    await mlflow.logBatch(runId, finalMetrics);

    const tags = [
      {
        key: 'confusion_matrix',
        value: JSON.stringify(evaluation.confusionMatrix),
      },
    ];
    await mlflow.logBatch(runId, undefined, undefined, tags);

    // Save model artifacts
    const __filename = fileURLToPath(import.meta.url);
    const __dirname = dirname(__filename);
    const artifactsPath = `${__dirname}/../mlruns/${experimentId}/${runId}/artifacts`;
    await model.save(`file://${artifactsPath}/model`);

    await mlflow.updateRun(runId, 'FINISHED');

    return {
      runId,
      metrics: evaluation,
    };
  } catch (error) {
    console.error(`Error in run ${runId}:`, error);
    // The status update is best effort: if it fails too (for example because
    // the tracking server is what broke), the caller must still see the
    // original error rather than the secondary one.
    try {
      await mlflow.updateRun(runId, 'FAILED');
    } catch (statusError) {
      console.error(`Could not mark run ${runId} as FAILED:`, statusError);
    }
    throw error;
  } finally {
    // Free the model's weights whether the run succeeded or failed.
    model?.dispose();
  }
}
250+
251+
/**
 * Runs the whole hyperparameter sweep.
 *
 * Steps: look up the MLflow experiment by name (creating it when the lookup
 * throws), generate the dataset once, execute one run per hyperparameter
 * combination, report the run with the highest test accuracy and, when that
 * accuracy is above 0.8, register the model and a model version for it.
 *
 * Errors are logged rather than rethrown. The dataset tensors are disposed
 * and the total timer is stopped whether or not the sweep succeeded.
 *
 * @returns {Promise<void>}
 */
async function main() {
  // Declared outside the try block so the finally clause can release it.
  let data;
  console.time('Total Execution Time');

  try {
    const experimentName = 'Neural_Network_Hyperparameter_Tuning_Fast';
    let experimentId;
    try {
      const experiment = await mlflow.getExperimentByName(experimentName);
      experimentId = experiment.experiment_id;
    } catch {
      // Any lookup failure is treated as "experiment does not exist yet".
      experimentId = await mlflow.createExperiment(experimentName);
    }
    console.log(`MLflow Experiment ID: ${experimentId}`);

    console.time('Data Generation');
    data = generateData();
    console.timeEnd('Data Generation');

    const results = [];
    let totalRuns = 0;
    const maxRuns =
      HYPERPARAMETER_SPACE.networkArchitectures.length *
      HYPERPARAMETER_SPACE.learningRates.length *
      HYPERPARAMETER_SPACE.batchSizes.length *
      HYPERPARAMETER_SPACE.dropoutRates.length;

    console.log(`\nStarting ${maxRuns} training runs...`);

    // Runs are executed strictly one after another.
    for (const architecture of HYPERPARAMETER_SPACE.networkArchitectures) {
      for (const learningRate of HYPERPARAMETER_SPACE.learningRates) {
        for (const batchSize of HYPERPARAMETER_SPACE.batchSizes) {
          for (const dropoutRate of HYPERPARAMETER_SPACE.dropoutRates) {
            totalRuns++;
            console.time(`Run ${totalRuns}`);

            const hyperparams = {
              architecture,
              learningRate,
              batchSize,
              dropoutRate,
            };

            console.log(`\nRun ${totalRuns}/${maxRuns}:`, hyperparams);

            const result = await runExperiment(experimentId, hyperparams, data);
            results.push(result);

            console.log(`Accuracy: ${result.metrics.testAccuracy.toFixed(4)}`);
            console.timeEnd(`Run ${totalRuns}`);
          }
        }
      }
    }

    // An empty grid produces no results; reduce() without a seed would throw.
    if (results.length === 0) {
      console.log('\nNo training runs were executed.');
      return;
    }

    // Ties keep the earlier run because the comparison is strict.
    const bestRun = results.reduce((best, current) =>
      current.metrics.testAccuracy > best.metrics.testAccuracy ? current : best
    );

    console.log('\nBest performing run:', bestRun.runId);
    console.log('Test accuracy:', bestRun.metrics.testAccuracy);

    // Register best model if accuracy is good enough
    if (bestRun.metrics.testAccuracy > 0.8) {
      const modelName = 'NeuralNetworkClassifier_Fast';

      // Creating the registered model is kept separate from creating the
      // version: a failure here (presumably most often because the name
      // already exists from an earlier execution — not verified against the
      // server response) must not prevent the new version from being added.
      try {
        await mlflow.createRegisteredModel(
          modelName,
          [{ key: 'task', value: 'classification' }],
          'Optimized neural network classifier'
        );
      } catch (e) {
        console.warn(
          `Could not create registered model "${modelName}"; trying to add a version anyway:`,
          e.message
        );
      }

      try {
        await mlflow.createModelVersion(
          modelName,
          `runs:/${bestRun.runId}/model`,
          bestRun.runId,
          [{ key: 'accuracy', value: bestRun.metrics.testAccuracy.toString() }]
        );
      } catch (e) {
        console.error('Model registration error:', e.message);
      }
    }

    console.log(
      `\nView results at http://localhost:5001/#/experiments/${experimentId}`
    );
  } catch (error) {
    console.error('Experiment failed:', error);
  } finally {
    // Release the dataset even when a run failed part-way through the sweep.
    if (data) {
      tf.dispose([data.trainX, data.trainY, data.testX, data.testY]);
    }
    console.timeEnd('Total Execution Time');
  }
}
345+
346+
// Script entry point: start the sweep when this module is executed. main()
// handles its own errors, so the returned promise is not awaited here.
main();

0 commit comments

Comments
 (0)