1111//
1212
1313using System ;
14+ using System . Linq ;
1415using System . Collections . Generic ;
1516using System . Management . Automation . Language ;
1617using Microsoft . Windows . PowerShell . ScriptAnalyzer . Generic ;
@@ -33,17 +34,67 @@ public class PossibleIncorrectComparisonWithNull : IScriptRule {
3334 public IEnumerable < DiagnosticRecord > AnalyzeScript ( Ast ast , string fileName ) {
3435 if ( ast == null ) throw new ArgumentNullException ( Strings . NullAstErrorMessage ) ;
3536
36- IEnumerable < Ast > binExpressionAsts = ast . FindAll ( testAst => testAst is BinaryExpressionAst , true ) ;
37+ IEnumerable < Ast > binExpressionAsts = ast . FindAll ( testAst => testAst is BinaryExpressionAst , false ) ;
3738
38- if ( binExpressionAsts != null ) {
39- foreach ( BinaryExpressionAst binExpressionAst in binExpressionAsts ) {
40- if ( ( binExpressionAst . Operator . Equals ( TokenKind . Equals ) || binExpressionAst . Operator . Equals ( TokenKind . Ceq )
41- || binExpressionAst . Operator . Equals ( TokenKind . Cne ) || binExpressionAst . Operator . Equals ( TokenKind . Ine ) || binExpressionAst . Operator . Equals ( TokenKind . Ieq ) )
42- && binExpressionAst . Right . Extent . Text . Equals ( "$null" , StringComparison . OrdinalIgnoreCase ) ) {
43- yield return new DiagnosticRecord ( Strings . PossibleIncorrectComparisonWithNullError , binExpressionAst . Extent , GetName ( ) , DiagnosticSeverity . Warning , fileName ) ;
39+ foreach ( BinaryExpressionAst binExpressionAst in binExpressionAsts ) {
40+ if ( ( binExpressionAst . Operator . Equals ( TokenKind . Equals ) || binExpressionAst . Operator . Equals ( TokenKind . Ceq )
41+ || binExpressionAst . Operator . Equals ( TokenKind . Cne ) || binExpressionAst . Operator . Equals ( TokenKind . Ine ) || binExpressionAst . Operator . Equals ( TokenKind . Ieq ) )
42+ && binExpressionAst . Right . Extent . Text . Equals ( "$null" , StringComparison . OrdinalIgnoreCase ) )
43+ {
44+ if ( IncorrectComparisonWithNull ( binExpressionAst , ast ) )
45+ {
46+ yield return new DiagnosticRecord ( Strings . PossibleIncorrectComparisonWithNullError , binExpressionAst . Extent , GetName ( ) , DiagnosticSeverity . Warning , fileName ) ;
4447 }
4548 }
4649 }
50+
51+ IEnumerable < Ast > funcAsts = ast . FindAll ( item => item is FunctionDefinitionAst , true ) . Union ( ast . FindAll ( item => item is FunctionMemberAst , true ) ) ;
52+ foreach ( Ast funcAst in funcAsts )
53+ {
54+ IEnumerable < Ast > binAsts = funcAst . FindAll ( item => item is BinaryExpressionAst , true ) ;
55+ foreach ( BinaryExpressionAst binAst in binAsts )
56+ {
57+ if ( IncorrectComparisonWithNull ( binAst , funcAst ) )
58+ {
59+ yield return new DiagnosticRecord ( Strings . PossibleIncorrectComparisonWithNullError , binAst . Extent , GetName ( ) , DiagnosticSeverity . Warning , fileName ) ;
60+ }
61+ }
62+ }
63+ }
64+
65+ private bool IncorrectComparisonWithNull ( BinaryExpressionAst binExpressionAst , Ast ast )
66+ {
67+ if ( ( binExpressionAst . Operator . Equals ( TokenKind . Equals ) || binExpressionAst . Operator . Equals ( TokenKind . Ceq )
68+ || binExpressionAst . Operator . Equals ( TokenKind . Cne ) || binExpressionAst . Operator . Equals ( TokenKind . Ine ) || binExpressionAst . Operator . Equals ( TokenKind . Ieq ) )
69+ && binExpressionAst . Right . Extent . Text . Equals ( "$null" , StringComparison . OrdinalIgnoreCase ) )
70+ {
71+ if ( binExpressionAst . Left . StaticType . IsArray )
72+ {
73+ return true ;
74+ }
75+ else if ( binExpressionAst . Left is VariableExpressionAst )
76+ {
77+ // ignores if the variable is a special variable
78+ if ( ! Helper . Instance . HasSpecialVars ( ( binExpressionAst . Left as VariableExpressionAst ) . VariablePath . UserPath ) )
79+ {
80+ Type lhsType = Helper . Instance . GetTypeFromAnalysis ( binExpressionAst . Left as VariableExpressionAst , ast ) ;
81+ if ( lhsType == null )
82+ {
83+ return true ;
84+ }
85+ else if ( lhsType . IsArray || lhsType == typeof ( object ) || lhsType == typeof ( Undetermined ) || lhsType == typeof ( Unreached ) )
86+ {
87+ return true ;
88+ }
89+ }
90+ }
91+ else if ( binExpressionAst . Left . StaticType == typeof ( object ) )
92+ {
93+ return true ;
94+ }
95+ }
96+
97+ return false ;
4798 }
4899
49100 /// <summary>
0 commit comments