@@ -213,6 +213,16 @@ fn type_and_tlv_indices<S: BaseState>(
213213 }
214214}
215215
216+ /// Checks a base buffer to verify if it is an Account without having to completely deserialize it
217+ fn is_initialized_account ( input : & [ u8 ] ) -> Result < bool , ProgramError > {
218+ const ACCOUNT_INITIALIZED_INDEX : usize = 108 ; // See state.rs#L99
219+
220+ if input. len ( ) != BASE_ACCOUNT_LENGTH {
221+ return Err ( ProgramError :: InvalidAccountData ) ;
222+ }
223+ Ok ( input[ ACCOUNT_INITIALIZED_INDEX ] != 0 )
224+ }
225+
216226fn get_extension < S : BaseState , V : Extension > ( tlv_data : & [ u8 ] ) -> Result < & V , ProgramError > {
217227 if V :: TYPE . get_account_type ( ) != S :: ACCOUNT_TYPE {
218228 return Err ( ProgramError :: InvalidAccountData ) ;
@@ -524,9 +534,13 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
524534
525535/// If AccountType is uninitialized, set it to the BaseState's ACCOUNT_TYPE;
526536/// if AccountType is already set, check is set correctly for BaseState
537+ /// This method assumes that the `base_data` has already been packed with data of the desired type.
527538pub fn set_account_type < S : BaseState > ( input : & mut [ u8 ] ) -> Result < ( ) , ProgramError > {
528539 check_min_len_and_not_multisig ( input, S :: LEN ) ?;
529- let ( _base_data, rest) = input. split_at_mut ( S :: LEN ) ;
540+ let ( base_data, rest) = input. split_at_mut ( S :: LEN ) ;
541+ if S :: ACCOUNT_TYPE == AccountType :: Account && !is_initialized_account ( base_data) ? {
542+ return Err ( ProgramError :: InvalidAccountData ) ;
543+ }
530544 if let Some ( ( account_type_index, _tlv_start_index) ) = type_and_tlv_indices :: < S > ( rest) ? {
531545 let mut account_type = AccountType :: try_from ( rest[ account_type_index] )
532546 . map_err ( |_| ProgramError :: InvalidAccountData ) ?;
@@ -1389,6 +1403,22 @@ mod test {
13891403 assert_eq ! ( err, ProgramError :: InvalidAccountData ) ;
13901404 }
13911405
1406+ #[ test]
1407+ fn test_set_account_type_wrongly ( ) {
1408+ // try to set Account account_type to Mint
1409+ let mut buffer = TEST_ACCOUNT_SLICE . to_vec ( ) ;
1410+ buffer. append ( & mut vec ! [ 0 ; 2 ] ) ;
1411+ let err = set_account_type :: < Mint > ( & mut buffer) . unwrap_err ( ) ;
1412+ assert_eq ! ( err, ProgramError :: InvalidAccountData ) ;
1413+
1414+ // try to set Mint account_type to Account
1415+ let mut buffer = TEST_MINT_SLICE . to_vec ( ) ;
1416+ buffer. append ( & mut vec ! [ 0 ; Account :: LEN - Mint :: LEN ] ) ;
1417+ buffer. append ( & mut vec ! [ 0 ; 2 ] ) ;
1418+ let err = set_account_type :: < Account > ( & mut buffer) . unwrap_err ( ) ;
1419+ assert_eq ! ( err, ProgramError :: InvalidAccountData ) ;
1420+ }
1421+
13921422 #[ test]
13931423 fn test_get_required_init_account_extensions ( ) {
13941424 // Some mint extensions with no required account extensions
0 commit comments