inetAds = ni.getInetAddresses();
+ while (inetAds.hasMoreElements()) {
+ InetAddress inetAddress = inetAds.nextElement();
+ if (inetAddress instanceof Inet4Address && !isReservedAddress(inetAddress)) {
+ String ipAddress = inetAddress.getHostAddress();
+ if (PATTERN_IPv4.matcher(ipAddress).find()) {
+ return ipAddress;
+ }
+ }
+ }
+ }
+ logger.warn("Can't get lan IP. Fall back to {}", IPADDRESS_LOCALHOST);
+ return IPADDRESS_LOCALHOST;
+ }
+}
diff --git a/p2p/src/main/java/org/tron/p2p/utils/ProtoUtil.java b/p2p/src/main/java/org/tron/p2p/utils/ProtoUtil.java
new file mode 100644
index 00000000000..bb3db746ebc
--- /dev/null
+++ b/p2p/src/main/java/org/tron/p2p/utils/ProtoUtil.java
@@ -0,0 +1,48 @@
+package org.tron.p2p.utils;
+
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import org.tron.p2p.base.Parameter;
+import org.tron.p2p.exception.P2pException;
+import org.tron.p2p.protos.Connect;
+import org.xerial.snappy.Snappy;
+
+public class ProtoUtil {
+
+ public static Connect.CompressMessage compressMessage(byte[] data) throws IOException {
+ Connect.CompressMessage.CompressType type = Connect.CompressMessage.CompressType.uncompress;
+ byte[] bytes = data;
+
+ byte[] compressData = Snappy.compress(data);
+ if (compressData.length < bytes.length) {
+ type = Connect.CompressMessage.CompressType.snappy;
+ bytes = compressData;
+ }
+
+ return Connect.CompressMessage.newBuilder()
+ .setData(ByteString.copyFrom(bytes))
+ .setType(type)
+ .build();
+ }
+
+ public static byte[] uncompressMessage(Connect.CompressMessage message)
+ throws IOException, P2pException {
+ byte[] data = message.getData().toByteArray();
+ if (message.getType().equals(Connect.CompressMessage.CompressType.uncompress)) {
+ return data;
+ }
+
+ int length = Snappy.uncompressedLength(data);
+ if (length >= Parameter.MAX_MESSAGE_LENGTH) {
+ throw new P2pException(
+ P2pException.TypeEnum.BIG_MESSAGE, "message is too big, len=" + length);
+ }
+
+ byte[] d2 = Snappy.uncompress(data);
+ if (d2.length >= Parameter.MAX_MESSAGE_LENGTH) {
+ throw new P2pException(
+ P2pException.TypeEnum.BIG_MESSAGE, "uncompressed is too big, len=" + length);
+ }
+ return d2;
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/crypto/ECDSASignature.java b/p2p/src/main/java/org/web3j/crypto/ECDSASignature.java
new file mode 100644
index 00000000000..864651c85bb
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/crypto/ECDSASignature.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.crypto;
+
+import java.math.BigInteger;
+
+/** An ECDSA Signature. */
+public class ECDSASignature {
+ public final BigInteger r;
+ public final BigInteger s;
+
+ public ECDSASignature(BigInteger r, BigInteger s) {
+ this.r = r;
+ this.s = s;
+ }
+
+ /**
+ * @return true if the S component is "low", that means it is below {@link Sign#HALF_CURVE_ORDER}.
+ * See
+ * BIP62.
+ */
+ public boolean isCanonical() {
+ return s.compareTo(Sign.HALF_CURVE_ORDER) <= 0;
+ }
+
+ /**
+ * Will automatically adjust the S component to be less than or equal to half the curve order, if
+ * necessary. This is required because for every signature (r,s) the signature (r, -s (mod N)) is
+ * a valid signature of the same message. However, we dislike the ability to modify the bits of a
+ * Bitcoin transaction after it's been signed, as that violates various assumed invariants. Thus
+ * in future only one of those forms will be considered legal and the other will be banned.
+ *
+ * @return the signature in a canonicalised form.
+ */
+ public ECDSASignature toCanonicalised() {
+ if (!isCanonical()) {
+ // The order of the curve is the number of valid points that exist on that curve.
+ // If S is in the upper half of the number of valid points, then bring it back to
+ // the lower half. Otherwise, imagine that
+ // N = 10
+ // s = 8, so (-8 % 10 == 2) thus both (r, 8) and (r, 2) are valid solutions.
+ // 10 - 8 == 2, giving us always the latter solution, which is canonical.
+ return new ECDSASignature(r, Sign.CURVE.getN().subtract(s));
+ } else {
+ return this;
+ }
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/crypto/ECKeyPair.java b/p2p/src/main/java/org/web3j/crypto/ECKeyPair.java
new file mode 100644
index 00000000000..357cecca5e3
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/crypto/ECKeyPair.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.crypto;
+
+import java.math.BigInteger;
+import java.security.KeyPair;
+import java.util.Arrays;
+import org.bouncycastle.crypto.digests.SHA256Digest;
+import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
+import org.bouncycastle.crypto.signers.ECDSASigner;
+import org.bouncycastle.crypto.signers.HMacDSAKCalculator;
+import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey;
+import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
+import org.web3j.utils.Numeric;
+
+/** Elliptic Curve SECP-256k1 generated key pair. */
+public class ECKeyPair {
+ private final BigInteger privateKey;
+ private final BigInteger publicKey;
+
+ public ECKeyPair(BigInteger privateKey, BigInteger publicKey) {
+ this.privateKey = privateKey;
+ this.publicKey = publicKey;
+ }
+
+ public BigInteger getPrivateKey() {
+ return privateKey;
+ }
+
+ public BigInteger getPublicKey() {
+ return publicKey;
+ }
+
+ /**
+ * Sign a hash with the private key of this key pair.
+ *
+ * @param transactionHash the hash to sign
+ * @return An {@link ECDSASignature} of the hash
+ */
+ public ECDSASignature sign(byte[] transactionHash) {
+ ECDSASigner signer = new ECDSASigner(new HMacDSAKCalculator(new SHA256Digest()));
+
+ ECPrivateKeyParameters privKey = new ECPrivateKeyParameters(privateKey, Sign.CURVE);
+ signer.init(true, privKey);
+ BigInteger[] components = signer.generateSignature(transactionHash);
+
+ return new ECDSASignature(components[0], components[1]).toCanonicalised();
+ }
+
+ public static ECKeyPair create(KeyPair keyPair) {
+ BCECPrivateKey privateKey = (BCECPrivateKey) keyPair.getPrivate();
+ BCECPublicKey publicKey = (BCECPublicKey) keyPair.getPublic();
+
+ BigInteger privateKeyValue = privateKey.getD();
+
+ // Ethereum does not use encoded public keys like bitcoin - see
+ // https://en.bitcoin.it/wiki/Elliptic_Curve_Digital_Signature_Algorithm for details
+ // Additionally, as the first bit is a constant prefix (0x04) we ignore this value
+ byte[] publicKeyBytes = publicKey.getQ().getEncoded(false);
+ BigInteger publicKeyValue =
+ new BigInteger(1, Arrays.copyOfRange(publicKeyBytes, 1, publicKeyBytes.length));
+
+ return new ECKeyPair(privateKeyValue, publicKeyValue);
+ }
+
+ public static ECKeyPair create(BigInteger privateKey) {
+ return new ECKeyPair(privateKey, Sign.publicKeyFromPrivate(privateKey));
+ }
+
+ public static ECKeyPair create(byte[] privateKey) {
+ return create(Numeric.toBigInt(privateKey));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ ECKeyPair ecKeyPair = (ECKeyPair) o;
+
+ if (privateKey != null
+ ? !privateKey.equals(ecKeyPair.privateKey)
+ : ecKeyPair.privateKey != null) {
+ return false;
+ }
+
+ return publicKey != null ? publicKey.equals(ecKeyPair.publicKey) : ecKeyPair.publicKey == null;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = privateKey != null ? privateKey.hashCode() : 0;
+ result = 31 * result + (publicKey != null ? publicKey.hashCode() : 0);
+ return result;
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/crypto/Hash.java b/p2p/src/main/java/org/web3j/crypto/Hash.java
new file mode 100644
index 00000000000..7c4e6d27d1b
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/crypto/Hash.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.crypto;
+
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import org.bouncycastle.crypto.digests.RIPEMD160Digest;
+import org.bouncycastle.crypto.digests.SHA512Digest;
+import org.bouncycastle.crypto.macs.HMac;
+import org.bouncycastle.crypto.params.KeyParameter;
+import org.bouncycastle.jcajce.provider.digest.Blake2b;
+import org.bouncycastle.jcajce.provider.digest.Keccak;
+import org.web3j.utils.Numeric;
+
+/** Cryptographic hash functions. */
+public class Hash {
+ private Hash() {}
+
+ /**
+ * Generates a digest for the given {@code input}.
+ *
+ * @param input The input to digest
+ * @param algorithm The hash algorithm to use
+ * @return The hash value for the given input
+ * @throws RuntimeException If we couldn't find any provider for the given algorithm
+ */
+ public static byte[] hash(byte[] input, String algorithm) {
+ try {
+ MessageDigest digest = MessageDigest.getInstance(algorithm.toUpperCase());
+ return digest.digest(input);
+ } catch (NoSuchAlgorithmException e) {
+ throw new RuntimeException("Couldn't find a " + algorithm + " provider", e);
+ }
+ }
+
+ /**
+ * Keccak-256 hash function.
+ *
+ * @param hexInput hex encoded input data with optional 0x prefix
+ * @return hash value as hex encoded string
+ */
+ public static String sha3(String hexInput) {
+ byte[] bytes = Numeric.hexStringToByteArray(hexInput);
+ byte[] result = sha3(bytes);
+ return Numeric.toHexString(result);
+ }
+
+ /**
+ * Keccak-256 hash function.
+ *
+ * @param input binary encoded input data
+ * @param offset of start of data
+ * @param length of data
+ * @return hash value
+ */
+ public static byte[] sha3(byte[] input, int offset, int length) {
+ Keccak.DigestKeccak kecc = new Keccak.Digest256();
+ kecc.update(input, offset, length);
+ return kecc.digest();
+ }
+
+ /**
+ * Keccak-256 hash function.
+ *
+ * @param input binary encoded input data
+ * @return hash value
+ */
+ public static byte[] sha3(byte[] input) {
+ return sha3(input, 0, input.length);
+ }
+
+ /**
+ * Keccak-256 hash function that operates on a UTF-8 encoded String.
+ *
+ * @param utf8String UTF-8 encoded string
+ * @return hash value as hex encoded string
+ */
+ public static String sha3String(String utf8String) {
+ return Numeric.toHexString(sha3(utf8String.getBytes(StandardCharsets.UTF_8)));
+ }
+
+ /**
+ * Generates SHA-256 digest for the given {@code input}.
+ *
+ * @param input The input to digest
+ * @return The hash value for the given input
+ * @throws RuntimeException If we couldn't find any SHA-256 provider
+ */
+ public static byte[] sha256(byte[] input) {
+ try {
+ MessageDigest digest = MessageDigest.getInstance("SHA-256");
+ return digest.digest(input);
+ } catch (NoSuchAlgorithmException e) {
+ throw new RuntimeException("Couldn't find a SHA-256 provider", e);
+ }
+ }
+
+ public static byte[] hmacSha512(byte[] key, byte[] input) {
+ HMac hMac = new HMac(new SHA512Digest());
+ hMac.init(new KeyParameter(key));
+ hMac.update(input, 0, input.length);
+ byte[] out = new byte[64];
+ hMac.doFinal(out, 0);
+ return out;
+ }
+
+ public static byte[] sha256hash160(byte[] input) {
+ byte[] sha256 = sha256(input);
+ RIPEMD160Digest digest = new RIPEMD160Digest();
+ digest.update(sha256, 0, sha256.length);
+ byte[] out = new byte[20];
+ digest.doFinal(out, 0);
+ return out;
+ }
+
+ /**
+ * Blake2-256 hash function.
+ *
+ * @param input binary encoded input data
+ * @return hash value
+ */
+ public static byte[] blake2b256(byte[] input) {
+ return new Blake2b.Blake2b256().digest(input);
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/crypto/Sign.java b/p2p/src/main/java/org/web3j/crypto/Sign.java
new file mode 100644
index 00000000000..f75dda8c7b4
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/crypto/Sign.java
@@ -0,0 +1,358 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.crypto;
+
+import static org.web3j.utils.Assertions.verifyPrecondition;
+
+import java.math.BigInteger;
+import java.security.SignatureException;
+import java.util.Arrays;
+import org.bouncycastle.asn1.x9.X9ECParameters;
+import org.bouncycastle.asn1.x9.X9IntegerConverter;
+import org.bouncycastle.crypto.ec.CustomNamedCurves;
+import org.bouncycastle.crypto.params.ECDomainParameters;
+import org.bouncycastle.math.ec.ECAlgorithms;
+import org.bouncycastle.math.ec.ECPoint;
+import org.bouncycastle.math.ec.FixedPointCombMultiplier;
+import org.bouncycastle.math.ec.custom.sec.SecP256K1Curve;
+import org.web3j.utils.Numeric;
+
+/**
+ * Transaction signing logic.
+ *
+ * Adapted from the
+ * BitcoinJ ECKey implementation.
+ */
+public class Sign {
+
+ public static final X9ECParameters CURVE_PARAMS = CustomNamedCurves.getByName("secp256k1");
+ static final ECDomainParameters CURVE =
+ new ECDomainParameters(
+ CURVE_PARAMS.getCurve(), CURVE_PARAMS.getG(), CURVE_PARAMS.getN(), CURVE_PARAMS.getH());
+ static final BigInteger HALF_CURVE_ORDER = CURVE_PARAMS.getN().shiftRight(1);
+
+ static final String MESSAGE_PREFIX = "\u0019Ethereum Signed Message:\n";
+
+ static byte[] getEthereumMessagePrefix(int messageLength) {
+ return MESSAGE_PREFIX.concat(String.valueOf(messageLength)).getBytes();
+ }
+
+ static byte[] getEthereumMessageHash(byte[] message) {
+ byte[] prefix = getEthereumMessagePrefix(message.length);
+
+ byte[] result = new byte[prefix.length + message.length];
+ System.arraycopy(prefix, 0, result, 0, prefix.length);
+ System.arraycopy(message, 0, result, prefix.length, message.length);
+
+ return Hash.sha3(result);
+ }
+
+ public static SignatureData signPrefixedMessage(byte[] message, ECKeyPair keyPair) {
+ return signMessage(getEthereumMessageHash(message), keyPair, false);
+ }
+
+ public static SignatureData signMessage(byte[] message, ECKeyPair keyPair) {
+ return signMessage(message, keyPair, true);
+ }
+
+ public static SignatureData signMessage(byte[] message, ECKeyPair keyPair, boolean needToHash) {
+ BigInteger publicKey = keyPair.getPublicKey();
+ byte[] messageHash;
+ if (needToHash) {
+ messageHash = Hash.sha3(message);
+ } else {
+ messageHash = message;
+ }
+
+ ECDSASignature sig = keyPair.sign(messageHash);
+ // Now we have to work backwards to figure out the recId needed to recover the signature.
+ int recId = -1;
+ for (int i = 0; i < 4; i++) {
+ BigInteger k = recoverFromSignature(i, sig, messageHash);
+ if (k != null && k.equals(publicKey)) {
+ recId = i;
+ break;
+ }
+ }
+ if (recId == -1) {
+ throw new RuntimeException(
+ "Could not construct a recoverable key. Are your credentials valid?");
+ }
+
+ int headerByte = recId + 27;
+
+ // 1 header + 32 bytes for R + 32 bytes for S
+ byte[] v = new byte[] {(byte) headerByte};
+ byte[] r = Numeric.toBytesPadded(sig.r, 32);
+ byte[] s = Numeric.toBytesPadded(sig.s, 32);
+
+ return new SignatureData(v, r, s);
+ }
+
+ /**
+ * Given the components of a signature and a selector value, recover and return the public key
+ * that generated the signature according to the algorithm in SEC1v2 section 4.1.6.
+ *
+ *
The recId is an index from 0 to 3 which indicates which of the 4 possible keys is the
+ * correct one. Because the key recovery operation yields multiple potential keys, the correct key
+ * must either be stored alongside the signature, or you must be willing to try each recId in turn
+ * until you find one that outputs the key you are expecting.
+ *
+ *
If this method returns null it means recovery was not possible and recId should be iterated.
+ *
+ *
Given the above two points, a correct usage of this method is inside a for loop from 0 to 3,
+ * and if the output is null OR a key that is not the one you expect, you try again with the next
+ * recId.
+ *
+ * @param recId Which possible key to recover.
+ * @param sig the R and S components of the signature, wrapped.
+ * @param message Hash of the data that was signed.
+ * @return An ECKey containing only the public part, or null if recovery wasn't possible.
+ */
+ public static BigInteger recoverFromSignature(int recId, ECDSASignature sig, byte[] message) {
+ verifyPrecondition(recId >= 0, "recId must be positive");
+ verifyPrecondition(sig.r.signum() >= 0, "r must be positive");
+ verifyPrecondition(sig.s.signum() >= 0, "s must be positive");
+ verifyPrecondition(message != null, "message cannot be null");
+
+ // 1.0 For j from 0 to h (h == recId here and the loop is outside this function)
+ // 1.1 Let x = r + jn
+ BigInteger n = CURVE.getN(); // Curve order.
+ BigInteger i = BigInteger.valueOf((long) recId / 2);
+ BigInteger x = sig.r.add(i.multiply(n));
+ // 1.2. Convert the integer x to an octet string X of length mlen using the conversion
+ // routine specified in Section 2.3.7, where mlen = ⌈(log2 p)/8⌉ or mlen = ⌈m/8⌉.
+ // 1.3. Convert the octet string (16 set binary digits)||X to an elliptic curve point R
+ // using the conversion routine specified in Section 2.3.4. If this conversion
+ // routine outputs "invalid", then do another iteration of Step 1.
+ //
+ // More concisely, what these points mean is to use X as a compressed public key.
+ BigInteger prime = SecP256K1Curve.q;
+ if (x.compareTo(prime) >= 0) {
+ // Cannot have point co-ordinates larger than this as everything takes place modulo Q.
+ return null;
+ }
+ // Compressed keys require you to know an extra bit of data about the y-coord as there are
+ // two possibilities. So it's encoded in the recId.
+ ECPoint R = decompressKey(x, (recId & 1) == 1);
+ // 1.4. If nR != point at infinity, then do another iteration of Step 1 (callers
+ // responsibility).
+ if (!R.multiply(n).isInfinity()) {
+ return null;
+ }
+ // 1.5. Compute e from M using Steps 2 and 3 of ECDSA signature verification.
+ BigInteger e = new BigInteger(1, message);
+ // 1.6. For k from 1 to 2 do the following. (loop is outside this function via
+ // iterating recId)
+ // 1.6.1. Compute a candidate public key as:
+ // Q = mi(r) * (sR - eG)
+ //
+ // Where mi(x) is the modular multiplicative inverse. We transform this into the following:
+ // Q = (mi(r) * s ** R) + (mi(r) * -e ** G)
+ // Where -e is the modular additive inverse of e, that is z such that z + e = 0 (mod n).
+ // In the above equation ** is point multiplication and + is point addition (the EC group
+ // operator).
+ //
+ // We can find the additive inverse by subtracting e from zero then taking the mod. For
+ // example the additive inverse of 3 modulo 11 is 8 because 3 + 8 mod 11 = 0, and
+ // -3 mod 11 = 8.
+ BigInteger eInv = BigInteger.ZERO.subtract(e).mod(n);
+ BigInteger rInv = sig.r.modInverse(n);
+ BigInteger srInv = rInv.multiply(sig.s).mod(n);
+ BigInteger eInvrInv = rInv.multiply(eInv).mod(n);
+ ECPoint q = ECAlgorithms.sumOfTwoMultiplies(CURVE.getG(), eInvrInv, R, srInv);
+
+ byte[] qBytes = q.getEncoded(false);
+ // We remove the prefix
+ return new BigInteger(1, Arrays.copyOfRange(qBytes, 1, qBytes.length));
+ }
+
+ /** Decompress a compressed public key (x co-ord and low-bit of y-coord). */
+ private static ECPoint decompressKey(BigInteger xBN, boolean yBit) {
+ X9IntegerConverter x9 = new X9IntegerConverter();
+ byte[] compEnc = x9.integerToBytes(xBN, 1 + x9.getByteLength(CURVE.getCurve()));
+ compEnc[0] = (byte) (yBit ? 0x03 : 0x02);
+ return CURVE.getCurve().decodePoint(compEnc);
+ }
+
+ /**
+ * Given an arbitrary piece of text and an Ethereum message signature encoded in bytes, returns
+ * the public key that was used to sign it. This can then be compared to the expected public key
+ * to determine if the signature was correct.
+ *
+ * @param message RLP encoded message.
+ * @param signatureData The message signature components
+ * @return the public key used to sign the message
+ * @throws SignatureException If the public key could not be recovered or if there was a signature
+ * format error.
+ */
+ public static BigInteger signedMessageToKey(byte[] message, SignatureData signatureData)
+ throws SignatureException {
+ return signedMessageHashToKey(Hash.sha3(message), signatureData);
+ }
+
+ /**
+ * Given an arbitrary message and an Ethereum message signature encoded in bytes, returns the
+ * public key that was used to sign it. This can then be compared to the expected public key to
+ * determine if the signature was correct.
+ *
+ * @param message The message.
+ * @param signatureData The message signature components
+ * @return the public key used to sign the message
+ * @throws SignatureException If the public key could not be recovered or if there was a signature
+ * format error.
+ */
+ public static BigInteger signedPrefixedMessageToKey(byte[] message, SignatureData signatureData)
+ throws SignatureException {
+ return signedMessageHashToKey(getEthereumMessageHash(message), signatureData);
+ }
+
+ /**
+ * Given an arbitrary message hash and an Ethereum message signature encoded in bytes, returns the
+ * public key that was used to sign it. This can then be compared to the expected public key to
+ * determine if the signature was correct.
+ *
+ * @param messageHash The message hash.
+ * @param signatureData The message signature components
+ * @return the public key used to sign the message
+ * @throws SignatureException If the public key could not be recovered or if there was a signature
+ * format error.
+ */
+ public static BigInteger signedMessageHashToKey(byte[] messageHash, SignatureData signatureData)
+ throws SignatureException {
+
+ byte[] r = signatureData.getR();
+ byte[] s = signatureData.getS();
+ verifyPrecondition(r != null && r.length == 32, "r must be 32 bytes");
+ verifyPrecondition(s != null && s.length == 32, "s must be 32 bytes");
+
+ int header = signatureData.getV()[0] & 0xFF;
+ // The header byte: 0x1B = first key with even y, 0x1C = first key with odd y,
+ // 0x1D = second key with even y, 0x1E = second key with odd y
+ if (header < 27 || header > 34) {
+ throw new SignatureException("Header byte out of range: " + header);
+ }
+
+ ECDSASignature sig =
+ new ECDSASignature(
+ new BigInteger(1, signatureData.getR()), new BigInteger(1, signatureData.getS()));
+
+ int recId = header - 27;
+ BigInteger key = recoverFromSignature(recId, sig, messageHash);
+ if (key == null) {
+ throw new SignatureException("Could not recover public key from signature");
+ }
+ return key;
+ }
+
+ /**
+ * Returns public key from the given private key.
+ *
+ * @param privKey the private key to derive the public key from
+ * @return BigInteger encoded public key
+ */
+ public static BigInteger publicKeyFromPrivate(BigInteger privKey) {
+ ECPoint point = publicPointFromPrivate(privKey);
+
+ byte[] encoded = point.getEncoded(false);
+ return new BigInteger(1, Arrays.copyOfRange(encoded, 1, encoded.length)); // remove prefix
+ }
+
+ /**
+ * Returns public key point from the given private key.
+ *
+ * @param privKey the private key to derive the public key from
+ * @return ECPoint public key
+ */
+ public static ECPoint publicPointFromPrivate(BigInteger privKey) {
+ /*
+ * TODO: FixedPointCombMultiplier currently doesn't support scalars longer than the group
+ * order, but that could change in future versions.
+ */
+ if (privKey.bitLength() > CURVE.getN().bitLength()) {
+ privKey = privKey.mod(CURVE.getN());
+ }
+ return new FixedPointCombMultiplier().multiply(CURVE.getG(), privKey);
+ }
+
+ /**
+ * Returns public key point from the given curve.
+ *
+ * @param bits representing the point on the curve
+ * @return BigInteger encoded public key
+ */
+ public static BigInteger publicFromPoint(byte[] bits) {
+ return new BigInteger(1, Arrays.copyOfRange(bits, 1, bits.length)); // remove prefix
+ }
+
+ public static class SignatureData {
+ private final byte[] v;
+ private final byte[] r;
+ private final byte[] s;
+
+ public SignatureData(byte v, byte[] r, byte[] s) {
+ this(new byte[] {v}, r, s);
+ }
+
+ public SignatureData(byte[] v, byte[] r, byte[] s) {
+ this.v = v;
+ this.r = r;
+ this.s = s;
+ }
+
+ public byte[] getV() {
+ return v;
+ }
+
+ public byte[] getR() {
+ return r;
+ }
+
+ public byte[] getS() {
+ return s;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ SignatureData that = (SignatureData) o;
+
+ if (!Arrays.equals(v, that.v)) {
+ return false;
+ }
+ if (!Arrays.equals(r, that.r)) {
+ return false;
+ }
+ return Arrays.equals(s, that.s);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = Arrays.hashCode(v);
+ result = 31 * result + Arrays.hashCode(r);
+ result = 31 * result + Arrays.hashCode(s);
+ return result;
+ }
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/exceptions/MessageDecodingException.java b/p2p/src/main/java/org/web3j/exceptions/MessageDecodingException.java
new file mode 100644
index 00000000000..e2232209d19
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/exceptions/MessageDecodingException.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.exceptions;
+
+/** Encoding exception. */
+public class MessageDecodingException extends RuntimeException {
+ public MessageDecodingException(String message) {
+ super(message);
+ }
+
+ public MessageDecodingException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/exceptions/MessageEncodingException.java b/p2p/src/main/java/org/web3j/exceptions/MessageEncodingException.java
new file mode 100644
index 00000000000..953a031e45d
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/exceptions/MessageEncodingException.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.exceptions;
+
+/** Encoding exception. */
+public class MessageEncodingException extends RuntimeException {
+ public MessageEncodingException(String message) {
+ super(message);
+ }
+
+ public MessageEncodingException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/utils/Assertions.java b/p2p/src/main/java/org/web3j/utils/Assertions.java
new file mode 100644
index 00000000000..77f0b7ad651
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/utils/Assertions.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.utils;
+
+/** Assertion utility functions. */
+public class Assertions {
+
+ /**
+ * Verify that the provided precondition holds true.
+ *
+ * @param assertionResult assertion value
+ * @param errorMessage error message if precondition failure
+ */
+ public static void verifyPrecondition(boolean assertionResult, String errorMessage) {
+ if (!assertionResult) {
+ throw new RuntimeException(errorMessage);
+ }
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/utils/Numeric.java b/p2p/src/main/java/org/web3j/utils/Numeric.java
new file mode 100644
index 00000000000..31fef2a2513
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/utils/Numeric.java
@@ -0,0 +1,254 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.utils;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.util.Arrays;
+import org.web3j.exceptions.MessageDecodingException;
+import org.web3j.exceptions.MessageEncodingException;
+
+/**
+ * Message codec functions.
+ *
+ *
Implementation as per https://github.com/ethereum/wiki/wiki/JSON-RPC#hex-value-encoding
+ */
+public final class Numeric {
+
+ private static final String HEX_PREFIX = "0x";
+
+ private Numeric() {}
+
+ public static String encodeQuantity(BigInteger value) {
+ if (value.signum() != -1) {
+ return HEX_PREFIX + value.toString(16);
+ } else {
+ throw new MessageEncodingException("Negative values are not supported");
+ }
+ }
+
+ public static BigInteger decodeQuantity(String value) {
+ if (isLongValue(value)) {
+ return BigInteger.valueOf(Long.parseLong(value));
+ }
+
+ if (!isValidHexQuantity(value)) {
+ throw new MessageDecodingException("Value must be in format 0x[1-9]+[0-9]* or 0x0");
+ }
+ try {
+ return new BigInteger(value.substring(2), 16);
+ } catch (NumberFormatException e) {
+ throw new MessageDecodingException("Negative ", e);
+ }
+ }
+
+ private static boolean isLongValue(String value) {
+ try {
+ Long.parseLong(value);
+ return true;
+ } catch (NumberFormatException e) {
+ return false;
+ }
+ }
+
+ private static boolean isValidHexQuantity(String value) {
+ if (value == null) {
+ return false;
+ }
+
+ if (value.length() < 3) {
+ return false;
+ }
+
+ if (!value.startsWith(HEX_PREFIX)) {
+ return false;
+ }
+
+ // If TestRpc resolves the following issue, we can reinstate this code
+ // https://github.com/ethereumjs/testrpc/issues/220
+ // if (value.length() > 3 && value.charAt(2) == '0') {
+ // return false;
+ // }
+
+ return true;
+ }
+
+ public static String cleanHexPrefix(String input) {
+ if (containsHexPrefix(input)) {
+ return input.substring(2);
+ } else {
+ return input;
+ }
+ }
+
+ public static String prependHexPrefix(String input) {
+ if (!containsHexPrefix(input)) {
+ return HEX_PREFIX + input;
+ } else {
+ return input;
+ }
+ }
+
+ public static boolean containsHexPrefix(String input) {
+ return !Strings.isEmpty(input)
+ && input.length() > 1
+ && input.charAt(0) == '0'
+ && input.charAt(1) == 'x';
+ }
+
+ public static BigInteger toBigInt(byte[] value, int offset, int length) {
+ return toBigInt((Arrays.copyOfRange(value, offset, offset + length)));
+ }
+
+ public static BigInteger toBigInt(byte[] value) {
+ return new BigInteger(1, value);
+ }
+
+ public static BigInteger toBigInt(String hexValue) {
+ String cleanValue = cleanHexPrefix(hexValue);
+ return toBigIntNoPrefix(cleanValue);
+ }
+
+ public static BigInteger toBigIntNoPrefix(String hexValue) {
+ return new BigInteger(hexValue, 16);
+ }
+
+ public static String toHexStringWithPrefix(BigInteger value) {
+ return HEX_PREFIX + value.toString(16);
+ }
+
+ public static String toHexStringNoPrefix(BigInteger value) {
+ return value.toString(16);
+ }
+
+ public static String toHexStringNoPrefix(byte[] input) {
+ return toHexString(input, 0, input.length, false);
+ }
+
+ public static String toHexStringWithPrefixZeroPadded(BigInteger value, int size) {
+ return toHexStringZeroPadded(value, size, true);
+ }
+
+ public static String toHexStringWithPrefixSafe(BigInteger value) {
+ String result = toHexStringNoPrefix(value);
+ if (result.length() < 2) {
+ result = Strings.zeros(1) + result;
+ }
+ return HEX_PREFIX + result;
+ }
+
+ public static String toHexStringNoPrefixZeroPadded(BigInteger value, int size) {
+ return toHexStringZeroPadded(value, size, false);
+ }
+
+ private static String toHexStringZeroPadded(BigInteger value, int size, boolean withPrefix) {
+ String result = toHexStringNoPrefix(value);
+
+ int length = result.length();
+ if (length > size) {
+ throw new UnsupportedOperationException("Value " + result + "is larger then length " + size);
+ } else if (value.signum() < 0) {
+ throw new UnsupportedOperationException("Value cannot be negative");
+ }
+
+ if (length < size) {
+ result = Strings.zeros(size - length) + result;
+ }
+
+ if (withPrefix) {
+ return HEX_PREFIX + result;
+ } else {
+ return result;
+ }
+ }
+
+ public static byte[] toBytesPadded(BigInteger value, int length) {
+ byte[] result = new byte[length];
+ byte[] bytes = value.toByteArray();
+
+ int bytesLength;
+ int srcOffset;
+ if (bytes[0] == 0) {
+ bytesLength = bytes.length - 1;
+ srcOffset = 1;
+ } else {
+ bytesLength = bytes.length;
+ srcOffset = 0;
+ }
+
+ if (bytesLength > length) {
+ throw new RuntimeException("Input is too large to put in byte array of size " + length);
+ }
+
+ int destOffset = length - bytesLength;
+ System.arraycopy(bytes, srcOffset, result, destOffset, bytesLength);
+ return result;
+ }
+
+ public static byte[] hexStringToByteArray(String input) {
+ String cleanInput = cleanHexPrefix(input);
+
+ int len = cleanInput.length();
+
+ if (len == 0) {
+ return new byte[] {};
+ }
+
+ byte[] data;
+ int startIdx;
+ if (len % 2 != 0) {
+ data = new byte[(len / 2) + 1];
+ data[0] = (byte) Character.digit(cleanInput.charAt(0), 16);
+ startIdx = 1;
+ } else {
+ data = new byte[len / 2];
+ startIdx = 0;
+ }
+
+ for (int i = startIdx; i < len; i += 2) {
+ data[(i + 1) / 2] =
+ (byte)
+ ((Character.digit(cleanInput.charAt(i), 16) << 4)
+ + Character.digit(cleanInput.charAt(i + 1), 16));
+ }
+ return data;
+ }
+
+ public static String toHexString(byte[] input, int offset, int length, boolean withPrefix) {
+ StringBuilder stringBuilder = new StringBuilder();
+ if (withPrefix) {
+ stringBuilder.append("0x");
+ }
+ for (int i = offset; i < offset + length; i++) {
+ stringBuilder.append(String.format("%02x", input[i] & 0xFF));
+ }
+
+ return stringBuilder.toString();
+ }
+
+ public static String toHexString(byte[] input) {
+ return toHexString(input, 0, input.length, true);
+ }
+
+ public static byte asByte(int m, int n) {
+ return (byte) ((m << 4) | n);
+ }
+
+ public static boolean isIntegerValue(BigDecimal value) {
+ return value.signum() == 0 || value.scale() <= 0 || value.stripTrailingZeros().scale() <= 0;
+ }
+}
diff --git a/p2p/src/main/java/org/web3j/utils/Strings.java b/p2p/src/main/java/org/web3j/utils/Strings.java
new file mode 100644
index 00000000000..34733642389
--- /dev/null
+++ b/p2p/src/main/java/org/web3j/utils/Strings.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2019 Web3 Labs Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package org.web3j.utils;
+
+import java.util.List;
+
+/** String utility functions. */
+public class Strings {
+
+ private Strings() {}
+
+ public static String toCsv(List src) {
+ // return src == null ? null : String.join(", ", src.toArray(new String[0]));
+ return join(src, ", ");
+ }
+
+ public static String join(List src, String delimiter) {
+ return src == null ? null : String.join(delimiter, src.toArray(new String[0]));
+ }
+
+ public static String capitaliseFirstLetter(String string) {
+ if (string == null || string.length() == 0) {
+ return string;
+ } else {
+ return string.substring(0, 1).toUpperCase() + string.substring(1);
+ }
+ }
+
+ public static String lowercaseFirstLetter(String string) {
+ if (string == null || string.length() == 0) {
+ return string;
+ } else {
+ return string.substring(0, 1).toLowerCase() + string.substring(1);
+ }
+ }
+
+ public static String zeros(int n) {
+ return repeat('0', n);
+ }
+
+ public static String repeat(char value, int n) {
+ return new String(new char[n]).replace("\0", String.valueOf(value));
+ }
+
+ public static boolean isEmpty(String s) {
+ return s == null || s.length() == 0;
+ }
+}
diff --git a/p2p/src/main/protos/Connect.proto b/p2p/src/main/protos/Connect.proto
new file mode 100644
index 00000000000..d03d123a963
--- /dev/null
+++ b/p2p/src/main/protos/Connect.proto
@@ -0,0 +1,60 @@
+syntax = "proto3";
+
+import "Discover.proto";
+
+option java_package = "org.tron.p2p.protos";
+option java_outer_classname = "Connect";
+
+message KeepAliveMessage {
+ int64 timestamp = 1;
+}
+
+message HelloMessage {
+ Endpoint from = 1;
+ int32 network_id = 2;
+ int32 code = 3;
+ int64 timestamp = 4;
+ int32 version = 5;
+}
+
+message StatusMessage {
+ Endpoint from = 1;
+ int32 version = 2;
+ int32 network_id = 3;
+ int32 maxConnections = 4;
+ int32 currentConnections = 5;
+ int64 timestamp = 6;
+}
+
+message CompressMessage {
+ enum CompressType {
+ uncompress = 0;
+ snappy = 1;
+ }
+
+ CompressType type = 1;
+ bytes data = 2;
+}
+
+enum DisconnectReason {
+ PEER_QUITING = 0x00;
+ BAD_PROTOCOL = 0x01;
+ TOO_MANY_PEERS = 0x02;
+ DUPLICATE_PEER = 0x03;
+ DIFFERENT_VERSION = 0x04;
+ RANDOM_ELIMINATION = 0x05;
+ EMPTY_MESSAGE = 0X06;
+ PING_TIMEOUT = 0x07;
+ DISCOVER_MODE = 0x08;
+ //DETECT_COMPLETE = 0x09;
+ NO_SUCH_MESSAGE = 0x0A;
+ BAD_MESSAGE = 0x0B;
+ TOO_MANY_PEERS_WITH_SAME_IP = 0x0C;
+ RECENT_DISCONNECT = 0x0D;
+ DUP_HANDSHAKE = 0x0E;
+ UNKNOWN = 0xFF;
+}
+
+message P2pDisconnectMessage {
+ DisconnectReason reason = 1;
+}
\ No newline at end of file
diff --git a/p2p/src/main/protos/Discover.proto b/p2p/src/main/protos/Discover.proto
new file mode 100644
index 00000000000..8a53761115c
--- /dev/null
+++ b/p2p/src/main/protos/Discover.proto
@@ -0,0 +1,50 @@
+syntax = "proto3";
+
+option java_package = "org.tron.p2p.protos";
+option java_outer_classname = "Discover";
+
+message Endpoint {
+ bytes address = 1;
+ int32 port = 2;
+ bytes nodeId = 3;
+ bytes addressIpv6 = 4;
+}
+
+message PingMessage {
+ Endpoint from = 1;
+ Endpoint to = 2;
+ int32 version = 3;
+ int64 timestamp = 4;
+}
+
+message PongMessage {
+ Endpoint from = 1;
+ int32 echo = 2;
+ int64 timestamp = 3;
+}
+
+message FindNeighbours {
+ Endpoint from = 1;
+ bytes targetId = 2;
+ int64 timestamp = 3;
+}
+
+message Neighbours {
+ Endpoint from = 1;
+ repeated Endpoint neighbours = 2;
+ int64 timestamp = 3;
+}
+
+message EndPoints {
+ repeated Endpoint nodes = 1;
+}
+
+message DnsRoot {
+ message TreeRoot {
+ bytes eRoot = 1;
+ bytes lRoot = 2;
+ int32 seq = 3;
+ }
+ TreeRoot treeRoot = 1;
+ bytes signature = 2;
+}
diff --git a/p2p/src/test/java/org/tron/p2p/P2pConfigTest.java b/p2p/src/test/java/org/tron/p2p/P2pConfigTest.java
new file mode 100644
index 00000000000..1933b8f47a9
--- /dev/null
+++ b/p2p/src/test/java/org/tron/p2p/P2pConfigTest.java
@@ -0,0 +1,77 @@
+package org.tron.p2p;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class P2pConfigTest {
+
+ @Test
+ public void testDefaultValues() {
+ P2pConfig config = new P2pConfig();
+ Assert.assertNotNull(config.getSeedNodes());
+ Assert.assertTrue(config.getSeedNodes().isEmpty());
+ Assert.assertNotNull(config.getActiveNodes());
+ Assert.assertTrue(config.getActiveNodes().isEmpty());
+ Assert.assertNotNull(config.getTrustNodes());
+ Assert.assertTrue(config.getTrustNodes().isEmpty());
+ Assert.assertNotNull(config.getNodeID());
+ Assert.assertEquals(64, config.getNodeID().length);
+ Assert.assertEquals(18888, config.getPort());
+ Assert.assertEquals(1, config.getNetworkId());
+ Assert.assertEquals(8, config.getMinConnections());
+ Assert.assertEquals(50, config.getMaxConnections());
+ Assert.assertEquals(2, config.getMinActiveConnections());
+ Assert.assertEquals(2, config.getMaxConnectionsWithSameIp());
+ Assert.assertTrue(config.isDiscoverEnable());
+ Assert.assertFalse(config.isDisconnectionPolicyEnable());
+ Assert.assertFalse(config.isNodeDetectEnable());
+ Assert.assertNotNull(config.getTreeUrls());
+ Assert.assertTrue(config.getTreeUrls().isEmpty());
+ Assert.assertNotNull(config.getPublishConfig());
+ }
+
+ @Test
+ public void testSettersAndGetters() {
+ P2pConfig config = new P2pConfig();
+
+ config.setPort(19999);
+ Assert.assertEquals(19999, config.getPort());
+
+ config.setNetworkId(42);
+ Assert.assertEquals(42, config.getNetworkId());
+
+ config.setMinConnections(10);
+ Assert.assertEquals(10, config.getMinConnections());
+
+ config.setMaxConnections(100);
+ Assert.assertEquals(100, config.getMaxConnections());
+
+ config.setMinActiveConnections(5);
+ Assert.assertEquals(5, config.getMinActiveConnections());
+
+ config.setMaxConnectionsWithSameIp(3);
+ Assert.assertEquals(3, config.getMaxConnectionsWithSameIp());
+
+ config.setDiscoverEnable(false);
+ Assert.assertFalse(config.isDiscoverEnable());
+
+ config.setDisconnectionPolicyEnable(true);
+ Assert.assertTrue(config.isDisconnectionPolicyEnable());
+
+ config.setNodeDetectEnable(true);
+ Assert.assertTrue(config.isNodeDetectEnable());
+
+ byte[] customId = new byte[64];
+ config.setNodeID(customId);
+ Assert.assertArrayEquals(customId, config.getNodeID());
+
+ config.setIp("10.0.0.1");
+ Assert.assertEquals("10.0.0.1", config.getIp());
+
+ config.setLanIp("192.168.0.1");
+ Assert.assertEquals("192.168.0.1", config.getLanIp());
+
+ config.setIpv6("::1");
+ Assert.assertEquals("::1", config.getIpv6());
+ }
+}
diff --git a/p2p/src/test/java/org/tron/p2p/P2pServiceTest.java b/p2p/src/test/java/org/tron/p2p/P2pServiceTest.java
new file mode 100644
index 00000000000..99e8cd35f18
--- /dev/null
+++ b/p2p/src/test/java/org/tron/p2p/P2pServiceTest.java
@@ -0,0 +1,104 @@
+package org.tron.p2p;
+
+import java.lang.reflect.Field;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.tron.p2p.base.Parameter;
+import org.tron.p2p.exception.P2pException;
+
+public class P2pServiceTest {
+
+ private P2pService p2pService;
+
+ @Before
+ public void init() {
+ p2pService = new P2pService();
+ // Reset handler state
+ Parameter.handlerList = new ArrayList<>();
+ Parameter.handlerMap = new HashMap<>();
+ }
+
+ @After
+ public void cleanup() {
+ try {
+ p2pService.close();
+ } catch (Exception e) {
+ // ignore cleanup errors
+ }
+ }
+
+ @Test
+ public void testGetVersion() {
+ Assert.assertEquals(Parameter.version, p2pService.getVersion());
+ }
+
+ @Test
+ public void testRegisterHandler() throws P2pException {
+ P2pEventHandler handler = new P2pEventHandler() {
+ {
+ Set types = new HashSet<>();
+ types.add((byte) 0x50);
+ this.messageTypes = types;
+ }
+ };
+ p2pService.register(handler);
+ Assert.assertTrue(Parameter.handlerList.contains(handler));
+ Assert.assertEquals(handler, Parameter.handlerMap.get((byte) 0x50));
+ }
+
+ @Test(expected = P2pException.class)
+ public void testRegisterDuplicateTypeThrows() throws P2pException {
+ P2pEventHandler handler1 = new P2pEventHandler() {
+ {
+ Set types = new HashSet<>();
+ types.add((byte) 0x60);
+ this.messageTypes = types;
+ }
+ };
+ P2pEventHandler handler2 = new P2pEventHandler() {
+ {
+ Set types = new HashSet<>();
+ types.add((byte) 0x60);
+ this.messageTypes = types;
+ }
+ };
+ p2pService.register(handler1);
+ p2pService.register(handler2); // should throw
+ }
+
+ @Test
+ public void testRegisterHandlerWithNullMessageTypes() throws P2pException {
+ P2pEventHandler handler = new P2pEventHandler() {};
+ // messageTypes is null by default
+ p2pService.register(handler);
+ Assert.assertTrue(Parameter.handlerList.contains(handler));
+ }
+
+ @Test
+ public void testCloseIdempotent() throws Exception {
+ // Set up minimal config to allow close without NPE
+ // The close method checks isShutdown flag
+ Field isShutdownField = P2pService.class.getDeclaredField("isShutdown");
+ isShutdownField.setAccessible(true);
+
+ // First close
+ isShutdownField.set(p2pService, false);
+ // We can't call start() without real network, but we can test the idempotent close
+ isShutdownField.set(p2pService, true);
+ // Second close should be a no-op
+ p2pService.close();
+ Assert.assertTrue((boolean) isShutdownField.get(p2pService));
+ }
+
+ @Test
+ public void testGetP2pStats() {
+ // statsManager is initialized in constructor, getP2pStats should work
+ Assert.assertNotNull(p2pService.getP2pStats());
+ }
+}
diff --git a/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerExtraTest.java b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerExtraTest.java
new file mode 100644
index 00000000000..bb69989eee0
--- /dev/null
+++ b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerExtraTest.java
@@ -0,0 +1,390 @@
+package org.tron.p2p.connection;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import java.lang.reflect.Field;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.HashSet;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.tron.p2p.P2pConfig;
+import org.tron.p2p.P2pEventHandler;
+import org.tron.p2p.base.Parameter;
+import org.tron.p2p.connection.business.handshake.DisconnectCode;
+import org.tron.p2p.connection.business.handshake.HandshakeService;
+import org.tron.p2p.connection.business.keepalive.KeepAliveService;
+import org.tron.p2p.connection.message.keepalive.PingMessage;
+import org.tron.p2p.connection.message.keepalive.PongMessage;
+import org.tron.p2p.exception.P2pException;
+import org.tron.p2p.protos.Connect.DisconnectReason;
+
+public class ChannelManagerExtraTest {
+
+ @Before
+ public void setUp() throws Exception {
+ Parameter.p2pConfig = new P2pConfig();
+ Parameter.handlerList = new ArrayList<>();
+ Parameter.handlerMap = new java.util.HashMap<>();
+ ChannelManager.getChannels().clear();
+ ChannelManager.getBannedNodes().invalidateAll();
+ // Initialize static services needed by processMessage
+ setStaticField(ChannelManager.class, "keepAliveService", new KeepAliveService());
+ setStaticField(ChannelManager.class, "handshakeService", new HandshakeService());
+ }
+
+ @After
+ public void tearDown() {
+ ChannelManager.getChannels().clear();
+ ChannelManager.getBannedNodes().invalidateAll();
+ Parameter.handlerList = new ArrayList<>();
+ Parameter.handlerMap = new java.util.HashMap<>();
+ }
+
+ @Test
+ public void testGetDisconnectReasonDifferentVersion() {
+ Assert.assertEquals(DisconnectReason.DIFFERENT_VERSION,
+ ChannelManager.getDisconnectReason(DisconnectCode.DIFFERENT_VERSION));
+ }
+
+ @Test
+ public void testGetDisconnectReasonTimeBanned() {
+ Assert.assertEquals(DisconnectReason.RECENT_DISCONNECT,
+ ChannelManager.getDisconnectReason(DisconnectCode.TIME_BANNED));
+ }
+
+ @Test
+ public void testGetDisconnectReasonDuplicatePeer() {
+ Assert.assertEquals(DisconnectReason.DUPLICATE_PEER,
+ ChannelManager.getDisconnectReason(DisconnectCode.DUPLICATE_PEER));
+ }
+
+ @Test
+ public void testGetDisconnectReasonTooManyPeers() {
+ Assert.assertEquals(DisconnectReason.TOO_MANY_PEERS,
+ ChannelManager.getDisconnectReason(DisconnectCode.TOO_MANY_PEERS));
+ }
+
+ @Test
+ public void testGetDisconnectReasonMaxConnectionWithSameIp() {
+ Assert.assertEquals(DisconnectReason.TOO_MANY_PEERS_WITH_SAME_IP,
+ ChannelManager.getDisconnectReason(DisconnectCode.MAX_CONNECTION_WITH_SAME_IP));
+ }
+
+ @Test
+ public void testGetDisconnectReasonUnknown() {
+ Assert.assertEquals(DisconnectReason.UNKNOWN,
+ ChannelManager.getDisconnectReason(DisconnectCode.UNKNOWN));
+ }
+
+ @Test
+ public void testGetDisconnectReasonNormal() {
+ Assert.assertEquals(DisconnectReason.UNKNOWN,
+ ChannelManager.getDisconnectReason(DisconnectCode.NORMAL));
+ }
+
+ @Test
+ public void testBanNodeNewBan() throws Exception {
+ InetAddress addr = InetAddress.getByName("10.0.0.1");
+ ChannelManager.banNode(addr, 10000L);
+ Long banTime = ChannelManager.getBannedNodes().getIfPresent(addr);
+ Assert.assertNotNull(banTime);
+ Assert.assertTrue(banTime > System.currentTimeMillis());
+ }
+
+ @Test
+ public void testBanNodeAlreadyBannedFuture() throws Exception {
+ InetAddress addr = InetAddress.getByName("10.0.0.2");
+ // Ban with a very long time first
+ ChannelManager.banNode(addr, 100000L);
+ Long firstBan = ChannelManager.getBannedNodes().getIfPresent(addr);
+
+ // Try to ban again with shorter time; should not overwrite since existing ban is in the future
+ ChannelManager.banNode(addr, 1L);
+ Long secondBan = ChannelManager.getBannedNodes().getIfPresent(addr);
+ Assert.assertEquals(firstBan, secondBan);
+ }
+
+ @Test
+ public void testNotifyDisconnectNullAddress() {
+ Channel channel = new Channel();
+ // inetSocketAddress is null by default
+ ChannelManager.notifyDisconnect(channel);
+ // Should not throw, just log and return
+ }
+
+ @Test
+ public void testNotifyDisconnectWithHandlers() throws Exception {
+ final boolean[] called = {false};
+ P2pEventHandler handler = new P2pEventHandler() {
+ {
+ this.messageTypes = new HashSet<>();
+ }
+
+ @Override
+ public void onDisconnect(Channel channel) {
+ called[0] = true;
+ }
+ };
+ Parameter.handlerList.add(handler);
+
+ Channel channel = createChannelWithAddress("10.0.0.3", 100);
+ ChannelManager.getChannels().put(channel.getInetSocketAddress(), channel);
+
+ ChannelManager.notifyDisconnect(channel);
+
+ Assert.assertTrue(called[0]);
+ Assert.assertFalse(ChannelManager.getChannels().containsKey(channel.getInetSocketAddress()));
+ }
+
+ @Test(expected = P2pException.class)
+ public void testProcessMessageNullData() throws Exception {
+ Channel channel = new Channel();
+ ChannelManager.processMessage(channel, null);
+ }
+
+ @Test(expected = P2pException.class)
+ public void testProcessMessageEmptyData() throws Exception {
+ Channel channel = new Channel();
+ ChannelManager.processMessage(channel, new byte[0]);
+ }
+
+ @Test(expected = P2pException.class)
+ public void testProcessMessagePositiveByteNoHandler() throws Exception {
+ Channel channel = new Channel();
+ // data[0] >= 0 means it goes to handMessage, which needs a handler
+ byte[] data = new byte[]{0x01, 0x02};
+ ChannelManager.processMessage(channel, data);
+ }
+
+ @Test
+ public void testProcessMessagePositiveByteDiscoveryMode() throws Exception {
+ // Register a handler for type 0x01
+ P2pEventHandler handler = new P2pEventHandler() {
+ {
+ this.messageTypes = new HashSet<>();
+ this.messageTypes.add((byte) 0x01);
+ }
+
+ @Override
+ public void onMessage(Channel channel, byte[] data) {
+ // do nothing
+ }
+ };
+ Parameter.handlerMap.put((byte) 0x01, handler);
+
+ // Create a channel in discovery mode
+ Channel channel = createChannelWithMockCtx("10.0.0.5", 200);
+ channel.setDiscoveryMode(true);
+
+ byte[] data = new byte[]{0x01, 0x02};
+ ChannelManager.processMessage(channel, data);
+ // Should send disconnect and close
+ }
+
+ @Test
+ public void testProcessMessageKeepAlivePing() throws Exception {
+ // Create a ping message and encode it
+ PingMessage ping = new PingMessage();
+ byte[] sendData = ping.getSendData();
+
+ Channel channel = createChannelWithMockCtx("10.0.0.10", 300);
+ ChannelManager.processMessage(channel, sendData);
+ // Should process without exception (sends pong)
+ }
+
+ @Test
+ public void testProcessMessageKeepAlivePong() throws Exception {
+ PongMessage pong = new PongMessage();
+ byte[] sendData = pong.getSendData();
+
+ Channel channel = createChannelWithMockCtx("10.0.0.11", 301);
+ channel.pingSent = System.currentTimeMillis();
+ channel.waitForPong = true;
+ ChannelManager.processMessage(channel, sendData);
+
+ Assert.assertFalse(channel.waitForPong);
+ }
+
+ @Test
+ public synchronized void testProcessPeerTimeBanned() throws Exception {
+ ChannelManager.getChannels().clear();
+ Parameter.p2pConfig.setMaxConnections(50);
+ Parameter.p2pConfig.setMaxConnectionsWithSameIp(2);
+
+ InetAddress addr = InetAddress.getByName("10.0.0.20");
+ // Ban the node with future timestamp
+ ChannelManager.getBannedNodes().put(addr, System.currentTimeMillis() + 100000);
+
+ Channel channel = new Channel();
+ InetSocketAddress sockAddr = new InetSocketAddress(addr, 100);
+ setFieldValue(channel, "inetSocketAddress", sockAddr);
+ setFieldValue(channel, "inetAddress", addr);
+
+ DisconnectCode code = ChannelManager.processPeer(channel);
+ Assert.assertEquals(DisconnectCode.TIME_BANNED, code);
+ }
+
+ @Test
+ public synchronized void testProcessPeerDuplicateClosesOlder() throws Exception {
+ ChannelManager.getChannels().clear();
+ Parameter.p2pConfig.setMaxConnections(50);
+ Parameter.p2pConfig.setMaxConnectionsWithSameIp(10);
+
+ // c1 is the existing channel (started earlier)
+ Channel c1 = createChannelWithMockCtx("10.0.0.30", 100);
+ c1.setNodeId("sameNodeId");
+
+ // Wait a bit so c2 starts later
+ Thread.sleep(5);
+
+ Channel c2 = createChannelWithMockCtx("10.0.0.31", 101);
+ c2.setNodeId("sameNodeId");
+
+ ChannelManager.getChannels().put(c1.getInetSocketAddress(), c1);
+
+ // c2 processing should detect duplicate; c1 started first so c2 is newer,
+ // c1 has earlier startTime so c2 should be rejected as DUPLICATE_PEER
+ DisconnectCode code = ChannelManager.processPeer(c2);
+ Assert.assertEquals(DisconnectCode.DUPLICATE_PEER, code);
+ }
+
+ @Test
+ public synchronized void testUpdateNodeIdSelf() throws Exception {
+ ChannelManager.getChannels().clear();
+ String selfNodeId = org.bouncycastle.util.encoders.Hex.toHexString(
+ Parameter.p2pConfig.getNodeID());
+
+ Channel channel = createChannelWithMockCtx("10.0.0.40", 100);
+ ChannelManager.getChannels().put(channel.getInetSocketAddress(), channel);
+
+ ChannelManager.updateNodeId(channel, selfNodeId);
+ Assert.assertTrue(channel.isDisconnect());
+ }
+
+ @Test
+ public synchronized void testUpdateNodeIdDuplicateClosesLater() throws Exception {
+ ChannelManager.getChannels().clear();
+
+ Channel c1 = createChannelWithMockCtx("10.0.0.50", 100);
+ c1.setNodeId("dupNode");
+ ChannelManager.getChannels().put(c1.getInetSocketAddress(), c1);
+
+ Thread.sleep(5);
+
+ Channel c2 = createChannelWithMockCtx("10.0.0.51", 101);
+ c2.setNodeId("dupNode");
+ ChannelManager.getChannels().put(c2.getInetSocketAddress(), c2);
+
+ // updateNodeId should close the one that started later
+ ChannelManager.updateNodeId(c2, "dupNode");
+ // One of them should be disconnected
+ Assert.assertTrue(c1.isDisconnect() || c2.isDisconnect());
+ }
+
+ @Test
+ public synchronized void testUpdateNodeIdNoDuplicate() throws Exception {
+ ChannelManager.getChannels().clear();
+
+ Channel c1 = createChannelWithMockCtx("10.0.0.60", 100);
+ c1.setNodeId("uniqueNode");
+ ChannelManager.getChannels().put(c1.getInetSocketAddress(), c1);
+
+ ChannelManager.updateNodeId(c1, "uniqueNode");
+ // Only 1 channel with this nodeId, should not close
+ Assert.assertFalse(c1.isDisconnect());
+ }
+
+ @Test
+ public void testHandMessageWithHandlerAndFirstMessage() throws Exception {
+ final boolean[] messageCalled = {false};
+ P2pEventHandler handler = new P2pEventHandler() {
+ {
+ this.messageTypes = new HashSet<>();
+ this.messageTypes.add((byte) 0x05);
+ }
+
+ @Override
+ public void onMessage(Channel channel, byte[] data) {
+ messageCalled[0] = true;
+ }
+ };
+ Parameter.handlerMap.put((byte) 0x05, handler);
+
+ final boolean[] connectCalled = {false};
+ P2pEventHandler connectHandler = new P2pEventHandler() {
+ {
+ this.messageTypes = new HashSet<>();
+ }
+
+ @Override
+ public void onConnect(Channel channel) {
+ connectCalled[0] = true;
+ }
+ };
+ Parameter.handlerList.add(connectHandler);
+
+ Channel channel = createChannelWithMockCtx("10.0.0.70", 100);
+ Parameter.p2pConfig.setMaxConnections(50);
+
+ byte[] data = new byte[]{0x05, 0x01, 0x02};
+ ChannelManager.processMessage(channel, data);
+
+ Assert.assertTrue(messageCalled[0]);
+ Assert.assertTrue(connectCalled[0]);
+ Assert.assertTrue(channel.isFinishHandshake());
+ }
+
+ @Test
+ public void testLogDisconnectReason() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.80", 100);
+ // Should not throw
+ ChannelManager.logDisconnectReason(channel, DisconnectReason.TOO_MANY_PEERS);
+ }
+
+ private Channel createChannelWithAddress(String ip, int port) throws Exception {
+ Channel channel = new Channel();
+ InetSocketAddress addr = new InetSocketAddress(ip, port);
+ setFieldValue(channel, "inetSocketAddress", addr);
+ setFieldValue(channel, "inetAddress", addr.getAddress());
+ return channel;
+ }
+
+ private Channel createChannelWithMockCtx(String ip, int port) throws Exception {
+ Channel channel = new Channel();
+ InetSocketAddress addr = new InetSocketAddress(ip, port);
+ setFieldValue(channel, "inetSocketAddress", addr);
+ setFieldValue(channel, "inetAddress", addr.getAddress());
+
+ ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class);
+ io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class);
+ when(mockCtx.channel()).thenReturn(mockNettyChannel);
+ when(mockNettyChannel.remoteAddress()).thenReturn(addr);
+ ChannelFuture mockFuture = mock(ChannelFuture.class);
+ when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture);
+ when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture);
+ when(mockCtx.close()).thenReturn(mockFuture);
+ when(mockNettyChannel.close()).thenReturn(mockFuture);
+ setFieldValue(channel, "ctx", mockCtx);
+
+ return channel;
+ }
+
+ private void setFieldValue(Object obj, String fieldName, Object value) throws Exception {
+ Field field = obj.getClass().getDeclaredField(fieldName);
+ field.setAccessible(true);
+ field.set(obj, value);
+ }
+
+ private void setStaticField(Class> clazz, String fieldName, Object value) throws Exception {
+ Field field = clazz.getDeclaredField(fieldName);
+ field.setAccessible(true);
+ field.set(null, value);
+ }
+}
diff --git a/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerTest.java b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerTest.java
new file mode 100644
index 00000000000..253651e7a99
--- /dev/null
+++ b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerTest.java
@@ -0,0 +1,129 @@
+package org.tron.p2p.connection;
+
+import java.lang.reflect.Field;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.Assert;
+import org.junit.Test;
+import org.tron.p2p.P2pConfig;
+import org.tron.p2p.base.Parameter;
+import org.tron.p2p.connection.business.handshake.DisconnectCode;
+
+@Slf4j(topic = "net")
+public class ChannelManagerTest {
+
+ @Test
+ public synchronized void testGetConnectionNum() throws Exception {
+ Channel c1 = new Channel();
+ InetSocketAddress a1 = new InetSocketAddress("100.1.1.1", 100);
+ Field field = c1.getClass().getDeclaredField("inetAddress");
+ field.setAccessible(true);
+ field.set(c1, a1.getAddress());
+
+ Channel c2 = new Channel();
+ InetSocketAddress a2 = new InetSocketAddress("100.1.1.2", 100);
+ field = c2.getClass().getDeclaredField("inetAddress");
+ field.setAccessible(true);
+ field.set(c2, a2.getAddress());
+
+ Channel c3 = new Channel();
+ InetSocketAddress a3 = new InetSocketAddress("100.1.1.2", 99);
+ field = c3.getClass().getDeclaredField("inetAddress");
+ field.setAccessible(true);
+ field.set(c3, a3.getAddress());
+
+ int cnt = ChannelManager.getConnectionNum(a1.getAddress());
+ Assert.assertTrue(cnt == 0);
+
+ ChannelManager.getChannels().put(a1, c1);
+ cnt = ChannelManager.getConnectionNum(a1.getAddress());
+ Assert.assertTrue(cnt == 1);
+
+ ChannelManager.getChannels().put(a2, c2);
+ cnt = ChannelManager.getConnectionNum(a2.getAddress());
+ Assert.assertTrue(cnt == 1);
+
+ ChannelManager.getChannels().put(a3, c3);
+ cnt = ChannelManager.getConnectionNum(a3.getAddress());
+ Assert.assertTrue(cnt == 2);
+ }
+
+ @Test
+ public synchronized void testNotifyDisconnect() throws Exception {
+ Channel c1 = new Channel();
+ InetSocketAddress a1 = new InetSocketAddress("100.1.1.1", 100);
+
+ Field field = c1.getClass().getDeclaredField("inetSocketAddress");
+ field.setAccessible(true);
+ field.set(c1, a1);
+
+ InetAddress inetAddress = a1.getAddress();
+ field = c1.getClass().getDeclaredField("inetAddress");
+ field.setAccessible(true);
+ field.set(c1, inetAddress);
+
+ ChannelManager.getChannels().put(a1, c1);
+
+ Long time = ChannelManager.getBannedNodes().getIfPresent(a1.getAddress());
+ Assert.assertTrue(ChannelManager.getChannels().size() == 1);
+ Assert.assertTrue(time == null);
+
+ ChannelManager.notifyDisconnect(c1);
+ time = ChannelManager.getBannedNodes().getIfPresent(a1.getAddress());
+ Assert.assertTrue(time != null);
+ Assert.assertTrue(ChannelManager.getChannels().size() == 0);
+ }
+
+ @Test
+ public synchronized void testProcessPeer() throws Exception {
+ clearChannels();
+ Parameter.p2pConfig = new P2pConfig();
+
+ Channel c1 = new Channel();
+ InetSocketAddress a1 = new InetSocketAddress("100.1.1.2", 100);
+
+ Field field = c1.getClass().getDeclaredField("inetSocketAddress");
+ field.setAccessible(true);
+ field.set(c1, a1);
+ field = c1.getClass().getDeclaredField("inetAddress");
+ field.setAccessible(true);
+ field.set(c1, a1.getAddress());
+
+ DisconnectCode code = ChannelManager.processPeer(c1);
+ Assert.assertTrue(code.equals(DisconnectCode.NORMAL));
+
+ Thread.sleep(5);
+
+ Parameter.p2pConfig.setMaxConnections(1);
+
+ Channel c2 = new Channel();
+ InetSocketAddress a2 = new InetSocketAddress("100.1.1.2", 99);
+
+ field = c2.getClass().getDeclaredField("inetSocketAddress");
+ field.setAccessible(true);
+ field.set(c2, a2);
+ field = c2.getClass().getDeclaredField("inetAddress");
+ field.setAccessible(true);
+ field.set(c2, a2.getAddress());
+
+ code = ChannelManager.processPeer(c2);
+ Assert.assertTrue(code.equals(DisconnectCode.TOO_MANY_PEERS));
+
+ Parameter.p2pConfig.setMaxConnections(2);
+ Parameter.p2pConfig.setMaxConnectionsWithSameIp(1);
+ code = ChannelManager.processPeer(c2);
+ Assert.assertTrue(code.equals(DisconnectCode.MAX_CONNECTION_WITH_SAME_IP));
+
+ Parameter.p2pConfig.setMaxConnectionsWithSameIp(2);
+ c1.setNodeId("cc");
+ c2.setNodeId("cc");
+ code = ChannelManager.processPeer(c2);
+ Assert.assertTrue(code.equals(DisconnectCode.DUPLICATE_PEER));
+ }
+
+ private void clearChannels() {
+ ChannelManager.getChannels().clear();
+ ChannelManager.getBannedNodes().invalidateAll();
+ }
+}
diff --git a/p2p/src/test/java/org/tron/p2p/connection/ChannelTest.java b/p2p/src/test/java/org/tron/p2p/connection/ChannelTest.java
new file mode 100644
index 00000000000..e901098b8c8
--- /dev/null
+++ b/p2p/src/test/java/org/tron/p2p/connection/ChannelTest.java
@@ -0,0 +1,357 @@
+package org.tron.p2p.connection;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.CorruptedFrameException;
+import io.netty.handler.timeout.ReadTimeoutException;
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.tron.p2p.P2pConfig;
+import org.tron.p2p.base.Parameter;
+import org.tron.p2p.connection.message.handshake.HelloMessage;
+import org.tron.p2p.exception.P2pException;
+
+public class ChannelTest {
+
+ private Channel channel;
+ private ChannelHandlerContext mockCtx;
+ private io.netty.channel.Channel mockNettyChannel;
+
+ @Before
+ public void setUp() {
+ Parameter.p2pConfig = new P2pConfig();
+ channel = new Channel();
+ mockCtx = mock(ChannelHandlerContext.class);
+ mockNettyChannel = mock(io.netty.channel.Channel.class);
+ when(mockCtx.channel()).thenReturn(mockNettyChannel);
+ }
+
+ @After
+ public void tearDown() {
+ ChannelManager.getChannels().clear();
+ ChannelManager.getBannedNodes().invalidateAll();
+ }
+
+ @Test
+ public void testInitWithNodeId() throws Exception {
+ io.netty.channel.ChannelPipeline mockPipeline =
+ mock(io.netty.channel.ChannelPipeline.class);
+ when(mockPipeline.addLast(
+ org.mockito.Mockito.anyString(), org.mockito.Mockito.any()))
+ .thenReturn(mockPipeline);
+
+ channel.init(mockPipeline, "abc123", false);
+ Assert.assertTrue(channel.isActive());
+ Assert.assertFalse(channel.isDiscoveryMode());
+ Assert.assertEquals("abc123", channel.getNodeId());
+ }
+
+ @Test
+ public void testInitWithEmptyNodeId() throws Exception {
+ io.netty.channel.ChannelPipeline mockPipeline =
+ mock(io.netty.channel.ChannelPipeline.class);
+ when(mockPipeline.addLast(
+ org.mockito.Mockito.anyString(), org.mockito.Mockito.any()))
+ .thenReturn(mockPipeline);
+
+ channel.init(mockPipeline, "", false);
+ Assert.assertFalse(channel.isActive());
+ }
+
+ @Test
+ public void testInitWithDiscoveryMode() throws Exception {
+ io.netty.channel.ChannelPipeline mockPipeline =
+ mock(io.netty.channel.ChannelPipeline.class);
+ when(mockPipeline.addLast(
+ org.mockito.Mockito.anyString(), org.mockito.Mockito.any()))
+ .thenReturn(mockPipeline);
+
+ channel.init(mockPipeline, "nodeId", true);
+ Assert.assertTrue(channel.isDiscoveryMode());
+ Assert.assertTrue(channel.isActive());
+ }
+
+ @Test
+ public void testSetChannelHandlerContext() {
+ InetSocketAddress address = new InetSocketAddress("192.168.1.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+
+ channel.setChannelHandlerContext(mockCtx);
+
+ Assert.assertEquals(mockCtx, channel.getCtx());
+ Assert.assertEquals(address, channel.getInetSocketAddress());
+ Assert.assertEquals(address.getAddress(), channel.getInetAddress());
+ Assert.assertFalse(channel.isTrustPeer());
+ }
+
+ @Test
+ public void testSetChannelHandlerContextWithTrustNode() {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ Parameter.p2pConfig.getTrustNodes().add(address.getAddress());
+
+ channel.setChannelHandlerContext(mockCtx);
+
+ Assert.assertTrue(channel.isTrustPeer());
+ Parameter.p2pConfig.getTrustNodes().clear();
+ }
+
+ @Test
+ public void testSetHelloMessage() throws Exception {
+ HelloMessage helloMsg = new HelloMessage(
+ org.tron.p2p.connection.business.handshake.DisconnectCode.NORMAL,
+ System.currentTimeMillis());
+
+ channel.setHelloMessage(helloMsg);
+
+ Assert.assertEquals(helloMsg, channel.getHelloMessage());
+ Assert.assertNotNull(channel.getNode());
+ Assert.assertNotNull(channel.getNodeId());
+ }
+
+ @Test
+ public void testProcessExceptionReadTimeout() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ ReadTimeoutException ex = ReadTimeoutException.INSTANCE;
+ channel.processException(ex);
+
+ Assert.assertTrue(channel.isDisconnect());
+ }
+
+ @Test
+ public void testProcessExceptionIOException() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ IOException ex = new IOException("connection reset");
+ channel.processException(ex);
+
+ Assert.assertTrue(channel.isDisconnect());
+ }
+
+ @Test
+ public void testProcessExceptionCorruptedFrame() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ CorruptedFrameException ex = new CorruptedFrameException("bad frame");
+ channel.processException(ex);
+
+ Assert.assertTrue(channel.isDisconnect());
+ }
+
+ @Test
+ public void testProcessExceptionP2pException() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ P2pException ex = new P2pException(P2pException.TypeEnum.BAD_MESSAGE, "test");
+ channel.processException(ex);
+
+ Assert.assertTrue(channel.isDisconnect());
+ }
+
+ @Test
+ public void testProcessExceptionGeneric() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ RuntimeException ex = new RuntimeException("unknown error");
+ channel.processException(ex);
+
+ Assert.assertTrue(channel.isDisconnect());
+ }
+
+ @Test
+ public void testProcessExceptionWithCausalLoop() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ // Create a causal loop: ex1 -> ex2 -> ex1
+ Exception ex1 = new Exception("loop1");
+ Exception ex2 = new Exception("loop2", ex1);
+ ex1.initCause(ex2);
+
+ channel.processException(ex1);
+ Assert.assertTrue(channel.isDisconnect());
+ }
+
+ @Test
+ public void testSendByteArrayWhenDisconnected() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ setCtxField(channel, mockCtx);
+ setInetSocketAddressField(channel, address);
+
+ channel.setDisconnect(true);
+ channel.send(new byte[]{0x01, 0x02});
+ // Should return early without writing; no NPE
+ }
+
+ @Test
+ public void testSendByteArraySuccess() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ ChannelFuture mockFuture = mock(ChannelFuture.class);
+ when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture);
+ when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture);
+ setCtxField(channel, mockCtx);
+ setInetSocketAddressField(channel, address);
+
+ channel.send(new byte[]{0x01, 0x02});
+ verify(mockCtx).writeAndFlush(org.mockito.Mockito.any());
+ }
+
+ @Test
+ public void testSendByteArrayException() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockCtx.writeAndFlush(org.mockito.Mockito.any()))
+ .thenThrow(new RuntimeException("write error"));
+ when(mockNettyChannel.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetSocketAddressField(channel, address);
+
+ channel.send(new byte[]{0x01, 0x02});
+ verify(mockNettyChannel).close();
+ }
+
+ @Test
+ public void testUpdateAvgLatency() {
+ channel.updateAvgLatency(100);
+ Assert.assertEquals(100, channel.getAvgLatency());
+
+ channel.updateAvgLatency(200);
+ Assert.assertEquals(150, channel.getAvgLatency());
+
+ channel.updateAvgLatency(300);
+ Assert.assertEquals(200, channel.getAvgLatency());
+ }
+
+ @Test
+ public void testCloseWithBanTime() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ channel.close(5000L);
+
+ Assert.assertTrue(channel.isDisconnect());
+ Assert.assertTrue(channel.getDisconnectTime() > 0);
+ Assert.assertNotNull(ChannelManager.getBannedNodes().getIfPresent(address.getAddress()));
+ verify(mockCtx).close();
+ }
+
+ @Test
+ public void testCloseDefaultBanTime() throws Exception {
+ InetSocketAddress address = new InetSocketAddress("10.0.0.2", 8080);
+ when(mockNettyChannel.remoteAddress()).thenReturn(address);
+ when(mockCtx.close()).thenReturn(mock(ChannelFuture.class));
+ setCtxField(channel, mockCtx);
+ setInetAddressField(channel, address);
+
+ channel.close();
+
+ Assert.assertTrue(channel.isDisconnect());
+ verify(mockCtx).close();
+ }
+
+ @Test
+ public void testEqualsAndHashCode() throws Exception {
+ Channel ch1 = new Channel();
+ Channel ch2 = new Channel();
+ InetSocketAddress addr = new InetSocketAddress("1.2.3.4", 100);
+
+ setInetSocketAddressField(ch1, addr);
+ setInetSocketAddressField(ch2, addr);
+
+ Assert.assertEquals(ch1, ch2);
+ Assert.assertEquals(ch1.hashCode(), ch2.hashCode());
+
+ Assert.assertTrue(ch1.equals(ch1));
+ Assert.assertFalse(ch1.equals(null));
+ Assert.assertFalse(ch1.equals("not a channel"));
+ }
+
+ @Test
+ public void testEqualsDifferentAddress() throws Exception {
+ Channel ch1 = new Channel();
+ Channel ch2 = new Channel();
+ setInetSocketAddressField(ch1, new InetSocketAddress("1.2.3.4", 100));
+ setInetSocketAddressField(ch2, new InetSocketAddress("1.2.3.5", 100));
+
+ Assert.assertNotEquals(ch1, ch2);
+ }
+
+ @Test
+ public void testToStringWithNodeId() throws Exception {
+ InetSocketAddress addr = new InetSocketAddress("1.2.3.4", 100);
+ setInetSocketAddressField(channel, addr);
+ channel.setNodeId("abcdef");
+
+ String result = channel.toString();
+ Assert.assertTrue(result.contains("abcdef"));
+ Assert.assertTrue(result.contains("1.2.3.4"));
+ }
+
+ @Test
+ public void testToStringWithoutNodeId() throws Exception {
+ InetSocketAddress addr = new InetSocketAddress("1.2.3.4", 100);
+ setInetSocketAddressField(channel, addr);
+ channel.setNodeId("");
+
+ String result = channel.toString();
+ Assert.assertTrue(result.contains(""));
+ }
+
+ private void setCtxField(Channel ch, ChannelHandlerContext ctx) throws Exception {
+ Field field = ch.getClass().getDeclaredField("ctx");
+ field.setAccessible(true);
+ field.set(ch, ctx);
+ }
+
+ private void setInetAddressField(Channel ch, InetSocketAddress addr) throws Exception {
+ Field inetField = ch.getClass().getDeclaredField("inetAddress");
+ inetField.setAccessible(true);
+ inetField.set(ch, addr.getAddress());
+ Field inetSockField = ch.getClass().getDeclaredField("inetSocketAddress");
+ inetSockField.setAccessible(true);
+ inetSockField.set(ch, addr);
+ }
+
+ private void setInetSocketAddressField(Channel ch, InetSocketAddress addr) throws Exception {
+ Field field = ch.getClass().getDeclaredField("inetSocketAddress");
+ field.setAccessible(true);
+ field.set(ch, addr);
+ }
+}
diff --git a/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceExtraTest.java b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceExtraTest.java
new file mode 100644
index 00000000000..0583e2015bb
--- /dev/null
+++ b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceExtraTest.java
@@ -0,0 +1,174 @@
+package org.tron.p2p.connection;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.tron.p2p.P2pConfig;
+import org.tron.p2p.base.Parameter;
+import org.tron.p2p.connection.business.pool.ConnPoolService;
+
+public class ConnPoolServiceExtraTest {
+
+ private ConnPoolService connPoolService;
+
+ @Before
+ public void setUp() {
+ Parameter.p2pConfig = new P2pConfig();
+ Parameter.handlerList = new ArrayList<>();
+ Parameter.handlerMap = new java.util.HashMap<>();
+ ChannelManager.getChannels().clear();
+ ChannelManager.getBannedNodes().invalidateAll();
+ connPoolService = new ConnPoolService();
+ }
+
+ @After
+ public void tearDown() {
+ ChannelManager.getChannels().clear();
+ ChannelManager.getBannedNodes().invalidateAll();
+ Parameter.handlerList = new ArrayList<>();
+ }
+
+ @Test
+ public void testOnConnectPassive() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.1", 100, false);
+ connPoolService.onConnect(channel);
+ Assert.assertEquals(1, connPoolService.getPassivePeersCount().get());
+ Assert.assertEquals(0, connPoolService.getActivePeersCount().get());
+ }
+
+ @Test
+ public void testOnConnectActive() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.2", 100, true);
+ connPoolService.onConnect(channel);
+ Assert.assertEquals(0, connPoolService.getPassivePeersCount().get());
+ Assert.assertEquals(1, connPoolService.getActivePeersCount().get());
+ }
+
+ @Test
+ public void testOnConnectDuplicate() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.3", 100, false);
+ connPoolService.onConnect(channel);
+ connPoolService.onConnect(channel); // duplicate add
+ Assert.assertEquals(1, connPoolService.getPassivePeersCount().get());
+ }
+
+ @Test
+ public void testOnDisconnectPassive() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.4", 100, false);
+ connPoolService.onConnect(channel);
+ Assert.assertEquals(1, connPoolService.getPassivePeersCount().get());
+
+ connPoolService.onDisconnect(channel);
+ Assert.assertEquals(0, connPoolService.getPassivePeersCount().get());
+ }
+
+ @Test
+ public void testOnDisconnectActive() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.5", 100, true);
+ connPoolService.onConnect(channel);
+ Assert.assertEquals(1, connPoolService.getActivePeersCount().get());
+
+ connPoolService.onDisconnect(channel);
+ Assert.assertEquals(0, connPoolService.getActivePeersCount().get());
+ }
+
+ @Test
+ public void testOnDisconnectNotInList() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.6", 100, false);
+ // Disconnect without connect first
+ connPoolService.onDisconnect(channel);
+ Assert.assertEquals(0, connPoolService.getPassivePeersCount().get());
+ }
+
+ @Test
+ public void testOnMessage() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.7", 100, false);
+ connPoolService.onMessage(channel, new byte[]{0x01});
+ // Should do nothing
+ }
+
+ @Test
+ public void testTriggerConnectConfigActiveNode() throws Exception {
+ InetSocketAddress addr = new InetSocketAddress("10.0.0.8", 100);
+ Parameter.p2pConfig.getActiveNodes().add(addr);
+
+ // Recreate ConnPoolService so configActiveNodes includes the address added above
+ connPoolService = new ConnPoolService();
+
+ connPoolService.triggerConnect(addr);
+ // Should return early because it's a config active node
+ // connectingPeersCount should not change
+ Assert.assertEquals(0, connPoolService.getConnectingPeersCount().get());
+
+ Parameter.p2pConfig.getActiveNodes().clear();
+ }
+
+ @Test
+ public void testTriggerConnectNonConfigNode() throws Exception {
+ InetSocketAddress addr = new InetSocketAddress("10.0.0.9", 100);
+ connPoolService.getConnectingPeersCount().set(5);
+
+ // This will decrement connecting peers count
+ connPoolService.triggerConnect(addr);
+ Assert.assertEquals(4, connPoolService.getConnectingPeersCount().get());
+ }
+
+ @Test
+ public void testClose() throws Exception {
+ // Add an active peer that is not disconnected
+ Channel channel = createChannelWithMockCtx("10.0.0.10", 100, false);
+ connPoolService.onConnect(channel);
+
+ connPoolService.close();
+ // Should send disconnect to all active peers and shutdown executors
+ }
+
+ @Test
+ public void testCloseAlreadyDisconnected() throws Exception {
+ Channel channel = createChannelWithMockCtx("10.0.0.11", 100, false);
+ channel.setDisconnect(true);
+ connPoolService.onConnect(channel);
+
+ connPoolService.close();
+ // Should skip sending disconnect to already disconnected channels
+ }
+
+ private Channel createChannelWithMockCtx(
+ String ip, int port, boolean active) throws Exception {
+ Channel channel = new Channel();
+ InetSocketAddress addr = new InetSocketAddress(ip, port);
+ setFieldValue(channel, "inetSocketAddress", addr);
+ setFieldValue(channel, "inetAddress", addr.getAddress());
+ if (active) {
+ setFieldValue(channel, "isActive", true);
+ }
+
+ ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class);
+ io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class);
+ when(mockCtx.channel()).thenReturn(mockNettyChannel);
+ when(mockNettyChannel.remoteAddress()).thenReturn(addr);
+ ChannelFuture mockFuture = mock(ChannelFuture.class);
+ when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture);
+ when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture);
+ when(mockCtx.close()).thenReturn(mockFuture);
+ when(mockNettyChannel.close()).thenReturn(mockFuture);
+ setFieldValue(channel, "ctx", mockCtx);
+
+ return channel;
+ }
+
+ private void setFieldValue(Object obj, String fieldName, Object value) throws Exception {
+ Field field = obj.getClass().getDeclaredField(fieldName);
+ field.setAccessible(true);
+ field.set(obj, value);
+ }
+}
diff --git a/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceTest.java b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceTest.java
new file mode 100644
index 00000000000..9be2411a0c8
--- /dev/null
+++ b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceTest.java
@@ -0,0 +1,128 @@
+package org.tron.p2p.connection;
+
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.tron.p2p.P2pConfig;
+import org.tron.p2p.base.Parameter;
+import org.tron.p2p.connection.business.pool.ConnPoolService;
+import org.tron.p2p.discover.Node;
+import org.tron.p2p.discover.NodeManager;
+
+public class ConnPoolServiceTest {
+
+ private static String localIp = "127.0.0.1";
+ private static int port = 10000;
+
+ @BeforeClass
+ public static void init() {
+ Parameter.p2pConfig = new P2pConfig();
+ Parameter.p2pConfig.setDiscoverEnable(false);
+ Parameter.p2pConfig.setPort(port);
+
+ NodeManager.init();
+ ChannelManager.init();
+ }
+
+ private void clearChannels() {
+ ChannelManager.getChannels().clear();
+ ChannelManager.getBannedNodes().invalidateAll();
+ }
+
+ @Test
+ public void getNodes_chooseHomeNode() {
+ InetSocketAddress localAddress =
+ new InetSocketAddress(Parameter.p2pConfig.getIp(), Parameter.p2pConfig.getPort());
+ Set inetInUse = new HashSet<>();
+ inetInUse.add(localAddress);
+
+ List connectableNodes = new ArrayList<>();
+ connectableNodes.add(NodeManager.getHomeNode());
+
+ ConnPoolService connPoolService = new ConnPoolService();
+ List nodes = connPoolService.getNodes(new HashSet<>(), inetInUse, connectableNodes, 1);
+ Assert.assertEquals(0, nodes.size());
+
+ nodes = connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, 1);
+ Assert.assertEquals(1, nodes.size());
+ }
+
+ @Test
+ public void getNodes_orderByUpdateTimeDesc() throws Exception {
+ clearChannels();
+ Node node1 = new Node(new InetSocketAddress(localIp, 90));
+ Field field = node1.getClass().getDeclaredField("updateTime");
+ field.setAccessible(true);
+ field.set(node1, System.currentTimeMillis());
+
+ Node node2 = new Node(new InetSocketAddress(localIp, 100));
+ field = node2.getClass().getDeclaredField("updateTime");
+ field.setAccessible(true);
+ field.set(node2, System.currentTimeMillis() + 10);
+
+ Assert.assertTrue(node1.getUpdateTime() < node2.getUpdateTime());
+
+ List connectableNodes = new ArrayList<>();
+ connectableNodes.add(node1);
+ connectableNodes.add(node2);
+
+ ConnPoolService connPoolService = new ConnPoolService();
+ List nodes =
+ connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, 2);
+ Assert.assertEquals(2, nodes.size());
+ Assert.assertTrue(nodes.get(0).getUpdateTime() > nodes.get(1).getUpdateTime());
+
+ int limit = 1;
+ List nodes2 =
+ connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, limit);
+ Assert.assertEquals(limit, nodes2.size());
+ }
+
+ @Test
+ public void getNodes_banNode() throws InterruptedException {
+ clearChannels();
+ InetSocketAddress inetSocketAddress = new InetSocketAddress(localIp, 90);
+ long banTime = 500L;
+ ChannelManager.banNode(inetSocketAddress.getAddress(), banTime);
+ Node node = new Node(inetSocketAddress);
+ List