Skip to content

Commit

Permalink
ui: support DNS SRV lookup & show resolved inetendpoint
Browse files Browse the repository at this point in the history
Signed-off-by: panxuesen <[email protected]>
  • Loading branch information
ailearncoder committed Dec 16, 2024
1 parent 4ba8794 commit cb3c174
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 14 deletions.
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jsr305 = "com.google.code.findbugs:jsr305:3.0.2"
junit = "junit:junit:4.13.2"
kotlinx-coroutines-android = "org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.0"
zxing-android-embedded = "com.journeyapps:zxing-android-embedded:4.3.0"
dnsjava = "dnsjava:dnsjava:3.4.2"

[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
Expand Down
1 change: 1 addition & 0 deletions tunnel/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ android {
dependencies {
implementation(libs.androidx.annotation)
implementation(libs.androidx.collection)
implementation(libs.dnsjava)
compileOnly(libs.jsr305)
testImplementation(libs.junit)
}
Expand Down
86 changes: 72 additions & 14 deletions tunnel/src/main/java/com/wireguard/config/InetEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@

import com.wireguard.util.NonNullForAll;

import org.xbill.DNS.DClass;
import org.xbill.DNS.ExtendedResolver;
import org.xbill.DNS.Lookup;
import org.xbill.DNS.Record;
import org.xbill.DNS.Resolver;
import org.xbill.DNS.SRVRecord;
import org.xbill.DNS.SimpleResolver;
import org.xbill.DNS.TextParseException;
import org.xbill.DNS.Type;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.URI;
Expand All @@ -15,6 +25,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

import androidx.annotation.Nullable;
Expand Down Expand Up @@ -46,6 +57,11 @@ private InetEndpoint(final String host, final boolean isResolved, final int port
public static InetEndpoint parse(final String endpoint) throws ParseException {
if (FORBIDDEN_CHARACTERS.matcher(endpoint).find())
throw new ParseException(InetEndpoint.class, endpoint, "Forbidden characters");
if (endpoint.contains("_")) {
// SRV records
final String host = endpoint.split(":")[0];
return new InetEndpoint(host, false, 0);
}
final URI uri;
try {
uri = new URI("wg://" + endpoint);
Expand Down Expand Up @@ -92,21 +108,60 @@ public Optional<InetEndpoint> getResolved() {
return Optional.of(this);
synchronized (lock) {
//TODO(zx2c4): Implement a real timeout mechanism using DNS TTL
if (Duration.between(lastResolution, Instant.now()).toMinutes() > 1) {
try {
// Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues.
final InetAddress[] candidates = InetAddress.getAllByName(host);
InetAddress address = candidates[0];
for (final InetAddress candidate : candidates) {
if (candidate instanceof Inet4Address) {
address = candidate;
break;
final long ttlTimeout = Duration.between(lastResolution, Instant.now()).toSeconds();
if (ttlTimeout > 60) {
resolved = null;
final String[] target = {host};
final int[] targetPort = {port};
if (host.contains("_")) {
// SRV records
try {
final Lookup lookup = new Lookup(host, Type.SRV, DClass.IN);
final Resolver resolver1 = new SimpleResolver("223.5.5.5");
final Resolver resolver2 = new SimpleResolver("223.6.6.6");
final Resolver[] resolvers = {resolver1, resolver2};
final Resolver extendedResolver = new ExtendedResolver(resolvers);
lookup.setResolver(extendedResolver);
lookup.setCache(null);
final Record[] records = lookup.run();
if (lookup.getResult() == Lookup.SUCCESSFUL) {
for (final Record record : records) {
final SRVRecord srv = (SRVRecord) record;
try {
target[0] = srv.getTarget().toString(true);
targetPort[0] = srv.getPort();
InetAddresses.parse(target[0]);
// Parsing ths host as a numeric address worked, so we don't need to do DNS lookups.
resolved = new InetEndpoint(target[0], true, targetPort[0]);
} catch (final ParseException ignored) {
// Failed to parse the host as a numeric address, so it must be a DNS hostname/FQDN.
}
// use the first SRV record and break out of loop
break;
}
} else {
System.out.println("SRV lookup failed: " + lookup.getErrorString());
}
} catch (final TextParseException | UnknownHostException e) {
System.out.println("SRV lookup failed: " + e.getMessage());
}
}
if (resolved == null) {
try {
// Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues.
final InetAddress[] candidates = InetAddress.getAllByName(target[0]);
InetAddress address = candidates[0];
for (final InetAddress candidate : candidates) {
if (candidate instanceof Inet4Address) {
address = candidate;
break;
}
}
resolved = new InetEndpoint(address.getHostAddress(), true, targetPort[0]);
lastResolution = Instant.now();
} catch (final UnknownHostException e) {
System.out.println("DNS lookup failed: " + e.getMessage());
}
resolved = new InetEndpoint(address.getHostAddress(), true, port);
lastResolution = Instant.now();
} catch (final UnknownHostException e) {
resolved = null;
}
}
return Optional.ofNullable(resolved);
Expand All @@ -121,6 +176,9 @@ public int hashCode() {
@Override
public String toString() {
final boolean isBareIpv6 = isResolved && BARE_IPV6.matcher(host).matches();
return (isBareIpv6 ? '[' + host + ']' : host) + ':' + port;
// Only show the port if it's non-zero
if (port > 0)
return (isBareIpv6 ? '[' + host + ']' : host) + ':' + port;
return (isBareIpv6 ? '[' + host + ']' : host);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ class TunnelDetailFragment : BaseFragment(), MenuProvider {
for (i in 0 until binding.peersLayout.childCount) {
val peer: TunnelDetailPeerBinding = DataBindingUtil.getBinding(binding.peersLayout.getChildAt(i))
?: continue
if (binding.config != null && i < binding.config!!.peers.size) {
val endpoint = binding.config!!.peers[i].endpoint.get()
val resolved = endpoint.resolved.get()
if (resolved.host != endpoint.host) {
if (endpoint.port != 0) {
peer.endpointText.text = "${endpoint.host}:${endpoint.port}\n${resolved.host}:${resolved.port}"
} else {
peer.endpointText.text = "${endpoint.host}\n${resolved.host}:${resolved.port}"
}
}
}
val publicKey = peer.item!!.publicKey
val peerStats = statistics.peer(publicKey)
if (peerStats == null || (peerStats.rxBytes == 0L && peerStats.txBytes == 0L)) {
Expand Down

0 comments on commit cb3c174

Please sign in to comment.