Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 13 additions & 27 deletions Common/Diagnostics/Logger.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Reflection;
using Amethyst.Common.Utility;
using System.Reflection;

namespace Amethyst.Common.Diagnostics
{
Expand All @@ -10,43 +11,28 @@ public static void WriteLine(object? message = null, ConsoleColor color = Consol
Console.WriteLine(message);
Console.ResetColor();
}

public static void Info(object? message)
public static void Info(object? message, CursorLocation? location = null)
{
WriteLine($"{(
Assembly.GetEntryAssembly()?.GetName()?.Name is { } name ?
$"[{name}] " :
"")}[INFO] {message}", ConsoleColor.White);
WriteLine($"{(location is not null && location.File != "<unknown>" ? location.ToString() + ": " : (Assembly.GetEntryAssembly()?.GetName().Name is not string name ? "Unknown: " : name.Trim() + ": "))} message: {message}", ConsoleColor.White);
}

public static void Debug(object? message)
public static void Debug(object? message, CursorLocation? location = null)
{
#if DEBUG
WriteLine($"{(
Assembly.GetEntryAssembly()?.GetName()?.Name is { } name ?
$"[{name}] " :
"")}[DEBUG] {message}", ConsoleColor.White);
WriteLine($"{(location is not null && location.File != "<unknown>" ? location.ToString() + ": " : (Assembly.GetEntryAssembly()?.GetName().Name is not string name ? "Unknown: " : name.Trim() + ": "))}message: {message}", ConsoleColor.White);
#endif
}

public static void Warn(string message) =>
WriteLine($"{(
Assembly.GetEntryAssembly()?.GetName()?.Name is { } name ?
$"[{name}] " :
"")}[WARN] {message}", ConsoleColor.Yellow);
public static void Warn(string message, CursorLocation? location = null) =>
WriteLine($"{(location is not null && location.File != "<unknown>" ? location?.ToString() + ": " : (Assembly.GetEntryAssembly()?.GetName().Name is not string name ? "Unknown: " : name.Trim() + ": "))}warning: {message}", ConsoleColor.Yellow);

public static void Error(string message) =>
WriteLine($"{(
Assembly.GetEntryAssembly()?.GetName()?.Name is { } name ?
$"[{name}] " :
"")}[ERROR] {message}", ConsoleColor.Red);
public static void Error(string message, CursorLocation? location = null) =>
WriteLine($"{(location is not null && location.File != "<unknown>" ? location?.ToString() + ": " : (Assembly.GetEntryAssembly()?.GetName().Name is not string name ? "Unknown: " : name.Trim() + ": "))}error: {message}", ConsoleColor.Red);

public static void Fatal(string message)
public static void Fatal(string message, CursorLocation? location = null)
{
WriteLine($"{(
Assembly.GetEntryAssembly()?.GetName()?.Name is { } name ?
$"[{name}] " :
"")}[FATAL] {message}", ConsoleColor.Magenta);
WriteLine($"{(location is not null && location.File != "<unknown>" ? location?.ToString() + ": " : (Assembly.GetEntryAssembly()?.GetName().Name is not string name ? "Unknown: " : name.Trim() + ": "))}fatal error: {message}", ConsoleColor.Magenta);
Environment.Exit(1);
}
}
Expand Down
3 changes: 3 additions & 0 deletions Common/Models/VariableSymbolModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@ public class VariableSymbolModel

[JsonProperty("address")]
public string Address { get; set; } = string.Empty;

[JsonProperty("is_vaddress")]
public bool IsVirtualTableAddress { get; set; } = false;
}
}
3 changes: 3 additions & 0 deletions Common/Models/VirtualFunctionSymbolModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@ public class VirtualFunctionSymbolModel

[JsonProperty("index")]
public uint Index { get; set; } = 0;

[JsonProperty("is_vdtor")]
public bool IsVirtualDestructor { get; set; } = false;
}
}
2 changes: 1 addition & 1 deletion Common/Tracking/FileTracker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public FileTracker(DirectoryInfo inputDirectory, FileInfo checksumFile, string[]
if (Filters.Any() && !Filters.Any(f => Path.GetRelativePath(InputDirectory.FullName, file.FullName).StartsWith(f)))
continue;
string filePath = file.FullName.NormalizeSlashes();
#if !DEBUG
#if DEBUG
string content = File.ReadAllText(file.FullName);
ulong hash = XXH64.DigestOf(Encoding.UTF8.GetBytes(content));

Expand Down
27 changes: 27 additions & 0 deletions Common/Utility/CursorLocation.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Amethyst.Common.Extensions;

namespace Amethyst.Common.Utility {
public class CursorLocation
{
public string File { get; set; }
public uint Line { get; set; }
public uint Column { get; set; }

public CursorLocation(string file, uint line, uint column)
{
if (string.IsNullOrEmpty(file) || !System.IO.File.Exists(file))
File = "<unknown>";
else
File = Path.GetFullPath(file).NormalizeSlashes();
Line = line;
Column = column;
}

override public string ToString()
{
if (File == "<unknown>")
return "";
return $"{File}({Line},{Column})";
}
}
}
19 changes: 19 additions & 0 deletions Common/Utility/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,24 @@ public static void CreateDefinitionFile(string defFile, IEnumerable<string> mang
sb.AppendLine("; End of generated file.");
File.WriteAllText(defFile, sb.ToString());
}

public static void WritePrefixedString(this BinaryWriter writer, string str) {
byte[] bytes = Encoding.UTF8.GetBytes(str);
writer.Write(bytes.Length);
writer.Write(bytes);
}

public static string ReadPrefixedString(this BinaryReader reader) {
int length = reader.ReadInt32();
byte[] bytes = reader.ReadBytes(length);
return Encoding.UTF8.GetString(bytes);
}

public static void Align(this BinaryWriter writer, int alignment = 16, byte pad = 0x00) {
long pos = writer.BaseStream.Position;
int padding = (int)((alignment - (pos % alignment)) % alignment);
for (int i = 0; i < padding; i++)
writer.Write(pad);
}
}
}
10 changes: 5 additions & 5 deletions ModuleTweaker/Amethyst.ModuleTweaker.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
<Nullable>enable</Nullable>
<AssemblyName>Amethyst.ModuleTweaker</AssemblyName>
<RootNamespace>Amethyst.ModuleTweaker</RootNamespace>
<Version>1.0.6</Version>
<Version>2.0.0</Version>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Common\Amethyst.Common.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="AsmResolver" Version="5.5.1" />
<PackageReference Include="AsmResolver.PE" Version="5.5.1" />
<PackageReference Include="AsmResolver.PE.File" Version="5.5.1" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Common\Amethyst.Common.csproj" />
</ItemGroup>

</Project>
132 changes: 97 additions & 35 deletions ModuleTweaker/Commands/MainCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,65 @@
using Amethyst.Common.Models;
using Amethyst.ModuleTweaker.Patching;
using AsmResolver.PE.File;
using AsmResolver.PE.Imports;
using CliFx;
using CliFx.Attributes;
using CliFx.Infrastructure;
using K4os.Hash.xxHash;
using Newtonsoft.Json;
using System.Globalization;

namespace Amethyst.ModuleTweaker.Commands
{
[Command(Description = "Patches or unpatches modules for runtime importing support.")]
public class MainCommand : ICommand
{
[CommandOption("module", 'm', Description = "The specified module path to patch.")]
[CommandOption("module", 'm', Description = "The specified module path to patch.", IsRequired = true)]
public string ModulePath { get; set; } = null!;

[CommandOption("symbols", 's', Description = "Path to directory containing *.symbols.json to use for patching.")]
[CommandOption("symbols", 's', Description = "Path to directory containing *.symbols.json to use for patching.", IsRequired = true)]
public string SymbolsPath { get; set; } = null!;

[CommandOption("output", 'o', Description = "Path to save temporary files, don't confuse with -m.")]
public string OutputPath { get; set; } = null!;

public ValueTask ExecuteAsync(IConsole console)
{
FileInfo module = new(ModulePath);
DirectoryInfo symbols = new(SymbolsPath);
if (module.Exists is false)
{
Logger.Warn("Couldn't patch module, specified module does not exist.");
DirectoryInfo symbolsDir = new(SymbolsPath);
if (module.Exists is false) {
Logger.Fatal("Couldn't patch module, specified module does not exist.");
return default;
}

if (symbols.Exists is false)
{
Logger.Warn("Couldn't patch module, specified symbols directory does not exist.");
if (symbolsDir.Exists is false) {
Logger.Fatal("Couldn't patch module, specified symbols directory does not exist.");
return default;
}

if (string.IsNullOrEmpty(OutputPath)) {
OutputPath = Path.GetFullPath(Path.Combine(SymbolsPath, "../"));
}
DirectoryInfo outDir = new(OutputPath);

ulong ParseAddress(string? address)
{
if (string.IsNullOrEmpty(address))
return 0x0;
if (address.StartsWith("0x", StringComparison.OrdinalIgnoreCase))
address = address[2..];
if (!ulong.TryParse(address, NumberStyles.HexNumber, null, out var addr))
return 0x0;
return addr;
}

SymbolFactory.Register(new SymbolType(1, "pe32+", "data"), () => new Patching.PE.V1.PEDataSymbol());
SymbolFactory.Register(new SymbolType(1, "pe32+", "function"), () => new Patching.PE.V1.PEFunctionSymbol());
HeaderFactory.Register(new HeaderType(1, "pe32+"), (args) => new Patching.PE.V1.PEImporterHeader());

// Collect all symbol files and accumulate mangled names
IEnumerable<FileInfo> symbolFiles = symbols.EnumerateFiles("*.json", SearchOption.AllDirectories);
HashSet<FunctionSymbolModel> methods = [];
HashSet<VariableSymbolModel> variables = [];
HashSet<VirtualTableSymbolModel> vtables = [];
HashSet<VirtualFunctionSymbolModel> vfuncs = [];
IEnumerable<FileInfo> symbolFiles = symbolsDir.EnumerateFiles("*.json", SearchOption.AllDirectories);
List<AbstractSymbol> symbols = [];
foreach (var symbolFile in symbolFiles)
{
using var stream = symbolFile.OpenRead();
Expand All @@ -50,29 +71,48 @@ public ValueTask ExecuteAsync(IConsole console)
switch (symbolJson.FormatVersion)
{
case 1:
foreach (var function in symbolJson.Functions)
{
foreach (var function in symbolJson.Functions) {
if (string.IsNullOrEmpty(function.Name))
continue;
methods.Add(function);
symbols.Add(new Patching.PE.V1.PEFunctionSymbol {
Name = function.Name,
IsVirtual = false,
IsSignature = function.Signature is not null,
Address = ParseAddress(function.Address),
Signature = function.Signature ?? string.Empty
});
}
foreach (var variable in symbolJson.Variables)
{
if (string.IsNullOrEmpty(variable.Name))
foreach (var vfunc in symbolJson.VirtualFunctions) {
if (string.IsNullOrEmpty(vfunc.Name))
continue;
variables.Add(variable);
symbols.Add(new Patching.PE.V1.PEFunctionSymbol {
Name = vfunc.Name,
IsVirtual = true,
VirtualIndex = vfunc.Index,
VirtualTable = vfunc.VirtualTable ?? "this",
IsDestructor = vfunc.IsVirtualDestructor,
HasStorage = vfunc.IsVirtualDestructor
});
}
foreach (var vtable in symbolJson.VirtualTables)
{
if (string.IsNullOrEmpty(vtable.Name))
foreach (var variable in symbolJson.Variables) {
if (string.IsNullOrEmpty(variable.Name))
continue;
vtables.Add(vtable);
symbols.Add(new Patching.PE.V1.PEDataSymbol {
Name = variable.Name,
IsVirtualTable = false,
Address = ParseAddress(variable.Address),
IsVirtualTableAddress = variable.IsVirtualTableAddress,
HasStorage = variable.IsVirtualTableAddress
});
}
foreach (var vfunc in symbolJson.VirtualFunctions)
{
if (string.IsNullOrEmpty(vfunc.Name))
foreach (var vtable in symbolJson.VirtualTables) {
if (string.IsNullOrEmpty(vtable.Name))
continue;
vfuncs.Add(vfunc);
symbols.Add(new Patching.PE.V1.PEDataSymbol {
Name = vtable.Name,
IsVirtualTable = true,
Address = ParseAddress(vtable.Address)
});
}
break;
}
Expand All @@ -82,13 +122,35 @@ public ValueTask ExecuteAsync(IConsole console)
try
{
// Patch the module
var file = PEFile.FromFile(ModulePath);
PEFileHelper helper = new(file);
if (helper.Patch(methods, variables, vtables, vfuncs))
var bytes = File.ReadAllBytes(ModulePath);
ulong hash = XXH64.DigestOf(bytes);
if (File.Exists(Path.Combine(outDir.FullName, "module_hash.txt"))) {
var existingHash = File.ReadAllText(Path.Combine(outDir.FullName, "module_hash.txt"));
if (ulong.TryParse(existingHash, NumberStyles.HexNumber, null, out var existingHashValue)) {
if (existingHashValue == hash) {
Logger.Info("Module hash matches previous hash, skipping patch.");
return default;
}
}
}

var peFile = PEFile.FromBytes(bytes);
if (peFile is null) {
Logger.Fatal("Failed to read module as a PE file.");
return default;
}
Logger.Info($"Loaded module '{ModulePath}' as PE file.");
var patcher = new Patching.PE.PEPatcher(peFile, symbols);

if (patcher.Patch())
{
file.AlignSections();
File.Copy(ModulePath, ModulePath + ".backup", true);
file.Write(ModulePath);
File.Copy(ModulePath, ModulePath + ".bak", true);
using var ms = new MemoryStream();
peFile.Write(ms);
var newBytes = ms.ToArray();
ulong newHash = XXH64.DigestOf(newBytes);
File.WriteAllBytes(ModulePath, newBytes);
File.WriteAllText(Path.Combine(outDir.FullName, "module_hash.txt"), newHash.ToString("X16"));
}
}
catch (Exception ex)
Expand Down
Loading