Skip to content
This repository was archived by the owner on Jan 7, 2023. It is now read-only.

Commit c2fa643

Browse files
authored
Merge pull request #294 from ndawe/master
update TMVA examples to use new DataLoader interface
2 parents 8526f80 + 53de3a2 commit c2fa643

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

examples/tmva/plot_multiclass.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,19 @@
3838
factory = TMVA.Factory('classifier', output,
3939
'AnalysisType=Multiclass:'
4040
'!V:Silent:!DrawProgressBar')
41+
42+
data = TMVA.DataLoader('.')
4143
for n in range(2):
42-
factory.AddVariable('f{0}'.format(n), 'F')
44+
data.AddVariable('f{0}'.format(n), 'F')
4345

4446
# Call root_numpy's utility functions to add events from the arrays
45-
add_classification_events(factory, X_train, y_train, weights=w_train)
46-
add_classification_events(factory, X_test, y_test, weights=w_test, test=True)
47+
add_classification_events(data, X_train, y_train, weights=w_train)
48+
add_classification_events(data, X_test, y_test, weights=w_test, test=True)
4749
# The following line is necessary if events have been added individually:
48-
factory.PrepareTrainingAndTestTree(TCut('1'), 'NormMode=EqualNumEvents')
50+
data.PrepareTrainingAndTestTree(TCut('1'), 'NormMode=EqualNumEvents')
4951

5052
# Train an MLP
51-
factory.BookMethod('MLP', 'MLP',
53+
factory.BookMethod(data, 'MLP', 'MLP',
5254
'NeuronType=tanh:NCycles=200:HiddenLayers=N+2,2:'
5355
'TestRate=5:EstimatorType=MSE')
5456
factory.TrainAllMethods()

examples/tmva/plot_regression.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,20 @@
2323
factory = TMVA.Factory('regressor', output,
2424
'AnalysisType=Regression:'
2525
'!V:Silent:!DrawProgressBar')
26-
factory.AddVariable('x', 'F')
27-
factory.AddTarget('y', 'F')
2826

29-
add_regression_events(factory, X, y)
30-
add_regression_events(factory, X, y, test=True)
27+
data = TMVA.DataLoader('.')
28+
data.AddVariable('x', 'F')
29+
data.AddTarget('y', 'F')
30+
31+
add_regression_events(data, X, y)
32+
add_regression_events(data, X, y, test=True)
3133
# The following line is necessary if events have been added individually:
32-
factory.PrepareTrainingAndTestTree(TCut('1'), '')
34+
data.PrepareTrainingAndTestTree(TCut('1'), '')
3335

34-
factory.BookMethod('BDT', 'BDT1',
36+
factory.BookMethod(data, 'BDT', 'BDT1',
3537
'nCuts=20:NTrees=1:MaxDepth=4:BoostType=AdaBoostR2:'
3638
'SeparationType=RegressionVariance')
37-
factory.BookMethod('BDT', 'BDT2',
39+
factory.BookMethod(data, 'BDT', 'BDT2',
3840
'nCuts=20:NTrees=300:MaxDepth=4:BoostType=AdaBoostR2:'
3941
'SeparationType=RegressionVariance')
4042
factory.TrainAllMethods()

examples/tmva/plot_twoclass.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,18 @@
3636
factory = TMVA.Factory('classifier', output,
3737
'AnalysisType=Classification:'
3838
'!V:Silent:!DrawProgressBar')
39+
data = TMVA.DataLoader('.')
3940
for n in range(n_vars):
40-
factory.AddVariable('f{0}'.format(n), 'F')
41+
data.AddVariable('f{0}'.format(n), 'F')
4142

4243
# Call root_numpy's utility functions to add events from the arrays
43-
add_classification_events(factory, X_train, y_train, weights=w_train)
44-
add_classification_events(factory, X_test, y_test, weights=w_test, test=True)
44+
add_classification_events(data, X_train, y_train, weights=w_train)
45+
add_classification_events(data, X_test, y_test, weights=w_test, test=True)
4546
# The following line is necessary if events have been added individually:
46-
factory.PrepareTrainingAndTestTree(TCut('1'), 'NormMode=EqualNumEvents')
47+
data.PrepareTrainingAndTestTree(TCut('1'), 'NormMode=EqualNumEvents')
4748

4849
# Train a classifier
49-
factory.BookMethod('Fisher', 'Fisher',
50+
factory.BookMethod(data, 'Fisher', 'Fisher',
5051
'Fisher:VarTransform=None:CreateMVAPdfs:'
5152
'PDFInterpolMVAPdf=Spline2:NbinsMVAPdf=50:'
5253
'NsmoothMVAPdf=10')

0 commit comments

Comments
 (0)