using Mono.Cecil; using Mono.Cecil.Cil; using System; using System.Collections.Generic; using System.IO; using System.Linq; using UnityEngine; using UnityEditorInternal; using UnityEditor; using UnityEditor.Compilation; using UnityEngine.Assertions; using System.Text; namespace com.bbbirder.injection.editor { public static class InjectHelper { /// /// inject target assembly /// /// /// /// /// is written internal static bool InjectAssembly(InjectionInfo[] injections, string inputAssemblyPath, string outputAssemblyPath, bool isEditor, BuildTarget buildTarget) { // set up assembly resolver var resolver = new DefaultAssemblyResolver(); var apiCompatibilityLevel = PlayerSettings.GetApiCompatibilityLevel(EditorUserBuildSettings.selectedBuildTargetGroup); var assemblySearchFolders = UnityInjectUtils.GetAssemblySearchFolders(isEditor, buildTarget); var systemAssemblyDirectories = CompilationPipeline.GetSystemAssemblyDirectories(apiCompatibilityLevel); resolver.AddSearchDirectory(Path.GetDirectoryName(outputAssemblyPath)); foreach (var folder in assemblySearchFolders) { resolver.AddSearchDirectory(folder); } foreach (var folder in systemAssemblyDirectories) { resolver.AddSearchDirectory(folder); } var IsPlayerAssembly = Path.GetFullPath(inputAssemblyPath).StartsWith(Path.GetFullPath("Library/")) || Path.GetFullPath(inputAssemblyPath).StartsWith(Path.GetFullPath("Temp/")); var targetAssembly = AssemblyDefinition.ReadAssembly(inputAssemblyPath, new ReaderParameters() { AssemblyResolver = resolver, ReadingMode = ReadingMode.Immediate, ReadSymbols = IsPlayerAssembly, InMemory = true, }); foreach(var inj in injections) { var providerMethod = inj.DryInjectMethod; if (providerMethod is null) continue; var methodName = inj.DryInjectMethodName; var targetType = GetCorrespondingType(targetAssembly.MainModule, inj.DryInjectAssmemble); if (targetType == null) { throw new($"Cannot find Type `{inj.DryInjectAssmemble}` in target assembly {inputAssemblyPath}"); } var providerAssemblyPath = providerMethod.DeclaringType.Assembly.GetAssemblyPath(); var providerAssembly = AssemblyDefinition.ReadAssembly(providerAssemblyPath, new ReaderParameters() { AssemblyResolver = resolver, ReadingMode = ReadingMode.Immediate, ReadSymbols = IsPlayerAssembly, InMemory = true, }); var providerType = GetCorrespondingType(providerAssembly.MainModule, providerMethod.DeclaringType); var targetMethod = providerType.FindMethod(providerMethod.GetSignature()).Resolve(); if (targetMethod is null) { throw new($"Cannot find Method `{methodName}` in Type `{providerMethod.DeclaringType}`"); } //add field targetType.AddInjectField(targetMethod, methodName, true); } /* //mark check var injected = targetAssembly.MainModule.Types.Any(t => Constants.INJECTED_MARK_NAME == t.Name && Constants.INJECTED_MARK_NAMESPACE == t.Namespace); if (injected) { targetAssembly.Release(); return false; } */ foreach (var group in injections.GroupBy(inj => inj.InjectedMethod)) { var injectedMethod = group.Key; if (injectedMethod is null) continue; var type = injectedMethod.DeclaringType; var methodName = injectedMethod.Name; var targetType = GetCorrespondingType(targetAssembly.MainModule, type); if (targetType is null) { throw new($"Cannot find Type `{type}` in target assembly {inputAssemblyPath}"); } var targetMethod = targetType.FindMethod(injectedMethod.GetSignature()).Resolve(); if (targetMethod is null) { throw new($"Cannot find Method `{methodName}` in Type `{type}`"); } //add origin var originalMethod = targetType.DuplicateOriginalMethod(targetMethod); //add field var injectionName = Constants.GetInjectedFieldName(methodName, targetMethod.GetSignature()); var (field, fieldInvoke) = targetType.AddInjectField(targetMethod, injectionName); //add method targetType.AddInjectionMethod(targetMethod, originalMethod, field, fieldInvoke, methodName); } //mark make var InjectedMark = new TypeDefinition( Constants.INJECTED_MARK_NAMESPACE, Constants.INJECTED_MARK_NAME, TypeAttributes.Class, targetAssembly.MainModule.TypeSystem.Object); targetAssembly.MainModule.Types.Add(InjectedMark); targetAssembly.Write(outputAssemblyPath, new WriterParameters() { WriteSymbols = IsPlayerAssembly, }); targetAssembly.Release(); return true; static TypeDefinition GetCorrespondingType(ModuleDefinition module, Type t1) { var typeDefinition = default(TypeDefinition); foreach (var type in GetContainingTypes(t1).Reverse()) { if (typeDefinition is null) { typeDefinition = module.Types.FirstOrDefault(t => type.FullName == t.FullName); } else { typeDefinition = typeDefinition.NestedTypes.FirstOrDefault(t => t.Name == type.Name); } if (typeDefinition is null) return null; } return typeDefinition; static IEnumerable GetContainingTypes(Type type) { while (type != null) { yield return type; type = type.DeclaringType; } } } } static MethodDefinition DuplicateOriginalMethod(this TypeDefinition targetType, MethodDefinition targetMethod) { var originName = Constants.GetOriginMethodName(targetMethod.Name, targetMethod.GetSignature()); var duplicatedMethod = targetType.Methods.FirstOrDefault(m => m.Name == originName); if (duplicatedMethod is null) { duplicatedMethod = targetMethod.Clone(); duplicatedMethod.IsPrivate = true; duplicatedMethod.Name = originName; targetType.Methods.Add(duplicatedMethod); } return duplicatedMethod; } static void Release(this AssemblyDefinition assemblyDefinition) { if (assemblyDefinition == null) return; assemblyDefinition.MainModule.AssemblyResolver?.Dispose(); assemblyDefinition.MainModule.SymbolReader?.Dispose(); assemblyDefinition.Dispose(); } static (FieldDefinition field, MethodReference fieldInvokeMethod) AddInjectField(this TypeDefinition targetType, MethodDefinition targetMethod, string injectionName, bool isPublic=false) { var HasThis = targetMethod.HasThis; var Parameters = targetMethod.Parameters; var GenericParameters = targetMethod.GenericParameters; // var CustomAttributes = targetMethod.CustomAttributes; var ReturnType = targetMethod.ReturnType; var ReturnVoid = targetMethod.IsReturnVoid(); //define delegate // var delegateParameters = new List(); // if(HasThis) delegateParameters.Add(targetType); // foreach(var p in Parameters) delegateParameters.Add(p.ParameterType); // var delegateType = targetType.Module.CreateDelegateType(Settings.GetDelegateTypeName(methodName),targetType,ReturnType,delegateParameters); // targetType.NestedTypes.Add(delegateType); var genName = targetMethod.IsReturnVoid() ? "System.Action" : "System.Func"; var genPCnt = Parameters.Count; if (!ReturnVoid) genPCnt++; if (HasThis) genPCnt++; if (genPCnt > 0) { genName += "`" + genPCnt; } var rawGenType = targetType.Module.FindType(Type.GetType(genName)); var genType = targetType.Module.ImportReference(rawGenType); var genInst = new GenericInstanceType(genType); if (HasThis) genInst.GenericArguments.Add(targetType); foreach (var p in Parameters) genInst.GenericArguments.Add(p.ParameterType); if (!ReturnVoid) genInst.GenericArguments.Add(ReturnType); //store fields var sfldInject = targetType.Fields.FirstOrDefault(f => f.Name == injectionName); if (sfldInject is null) { sfldInject = new FieldDefinition(injectionName, FieldAttributes.Static | (isPublic ? FieldAttributes.Public : FieldAttributes.Private), genInst); targetType.Fields.Add(sfldInject); } // var sfldOrigin = new FieldDefinition(originName, // FieldAttributes.Private|FieldAttributes.Static|FieldAttributes.Assembly, // targetType.Module.ImportReference(typeof(Delegate))); // var resMth = genInst.Resolve(); var genMtd = rawGenType.FindMethodByName("Invoke"); // genMtd.DeclaringType = genInst; var mnlMth = new MethodReference(genMtd.Name, genMtd.ReturnType, genInst) { ExplicitThis = false, HasThis = true, CallingConvention = genMtd.CallingConvention }; foreach (var p in genMtd.Parameters) mnlMth.Parameters.Add(p); return (sfldInject, mnlMth); } static void AddInjectionMethod( this TypeDefinition targetType, MethodDefinition targetMethod, MethodDefinition originalMethod, FieldDefinition delegateField, MethodReference fieldInvoke, string methodName ) { var argidx = 0; var HasThis = targetMethod.HasThis; var Parameters = targetMethod.Parameters; // var GenericParameters = targetMethod.GenericParameters; // var CustomAttributes = targetMethod.CustomAttributes; var ReturnType = targetMethod.ReturnType; //redirect method if (!targetMethod.HasBody) { throw new ArgumentException($"method {targetMethod.Name} in type {targetType.Name} dont have a body"); } targetMethod.Body.Instructions.Clear(); var delegateType = delegateField.FieldType.Resolve(); var ilProcessor = targetMethod.Body.GetILProcessor(); var tagOp = Instruction.Create(OpCodes.Nop); //check null ilProcessor.Append(Instruction.Create(OpCodes.Ldsfld, delegateField)); ilProcessor.Append(Instruction.Create(OpCodes.Brtrue_S, tagOp)); // //set field // if(HasThis) // ilProcessor.Append(Instruction.Create(OpCodes.Ldarg_0)); // else // ilProcessor.Append(Instruction.Create(OpCodes.Ldnull)); // ilProcessor.Append(Instruction.Create(OpCodes.Ldftn, originalMethod)); // ilProcessor.Append(Instruction.Create(OpCodes.Newobj,delegateType.FindMethod(".ctor"))); // ilProcessor.Append(Instruction.Create(OpCodes.Stsfld,delegateField)); //invoke origin argidx = 0; if (HasThis) ilProcessor.Append(ilProcessor.createLdarg(argidx++)); for (var i = 0; i < Parameters.Count; i++) { var pType = Parameters[i].ParameterType; ilProcessor.Append(ilProcessor.createLdarg(argidx++)); // if(pType.IsValueType) // ilProcessor.Append(Instruction.Create(OpCodes.Box,pType)); } if (HasThis) ilProcessor.Append(Instruction.Create(OpCodes.Callvirt, originalMethod)); else ilProcessor.Append(Instruction.Create(OpCodes.Call, originalMethod)); // if(originalMethod.IsReturnVoid()) // ilProcessor.Append(Instruction.Create(OpCodes.Pop)); ilProcessor.Append(Instruction.Create(OpCodes.Ret)); //invoke ilProcessor.Append(tagOp); ilProcessor.Append(Instruction.Create(OpCodes.Ldsfld, delegateField)); argidx = 0; if (HasThis) ilProcessor.Append(ilProcessor.createLdarg(argidx++)); for (var i = 0; i < Parameters.Count; i++) { var pType = Parameters[i].ParameterType; ilProcessor.Append(ilProcessor.createLdarg(argidx++)); // if(pType.IsValueType) // ilProcessor.Append(Instruction.Create(OpCodes.Box,pType)); } ilProcessor.Append(Instruction.Create(OpCodes.Callvirt, (fieldInvoke))); // Fixes: conditional boxing here is unnecessary // if(ReturnType.IsComplexValueType()) // ilProcessor.Append(Instruction.Create(OpCodes.Box,ReturnType)); ilProcessor.Append(Instruction.Create(OpCodes.Nop)); ilProcessor.Append(Instruction.Create(OpCodes.Ret)); } static Instruction createLdarg(this ILProcessor ilProcessor, int i) { if (i < s_ldargs.Length) { return Instruction.Create(s_ldargs[i]); } else if (i < 256) { return ilProcessor.Create(OpCodes.Ldarg_S, (byte)i); } else { return ilProcessor.Create(OpCodes.Ldarg, (short)i); } } /// /// Create a clone of the given method definition /// public static MethodDefinition Clone(this MethodDefinition source) { var result = new MethodDefinition(source.Name, source.Attributes, source.ReturnType) { ImplAttributes = source.ImplAttributes, SemanticsAttributes = source.SemanticsAttributes, HasThis = source.HasThis, ExplicitThis = source.ExplicitThis, CallingConvention = source.CallingConvention }; foreach (var p in source.Parameters) result.Parameters.Add(p); // foreach (var p in source.CustomAttributes) result.CustomAttributes.Add(p); foreach (var p in source.GenericParameters) result.GenericParameters.Add(p); if (source.HasBody) { result.Body = source.Body.Clone(result); } return result; } /// /// Create a clone of the given method body /// public static MethodBody Clone(this MethodBody source, MethodDefinition target) { var result = new MethodBody(target) { InitLocals = source.InitLocals, MaxStackSize = source.MaxStackSize }; var worker = result.GetILProcessor(); if (source.HasVariables) { foreach (var v in source.Variables) { result.Variables.Add(v); } } foreach (var i in source.Instructions) { // Poor mans clone, but sufficient for our needs var clone = Instruction.Create(OpCodes.Nop); clone.OpCode = i.OpCode; clone.Operand = i.Operand; worker.Append(clone); } return result; } internal static bool IsReturnVoid(this MethodDefinition md) => md.ReturnType.ToString() == voidType.ToString(); internal static bool IsReturnValueType(this MethodDefinition md) => !md.IsReturnVoid() && md.ReturnType.IsValueType; internal static bool IsComplexValueType(this TypeReference td) => td.ToString() != voidType.ToString() && !td.IsPrimitive; internal static Type GetUnderlyingType(this TypeReference td) => td.IsPrimitive ? Type.GetType(td.Name) : objType; internal static string GetSignature(this MethodDefinition md) { var builder = new StringBuilder(); builder.Append(md.Name); if (md.HasGenericParameters) { builder.Append('`'); builder.Append(md.GenericParameters.Count); } builder.Append('('); if (md.HasParameters) { var parameters = md.Parameters; for (int i = 0; i < parameters.Count; i++) { ParameterDefinition parameterDefinition = parameters[i]; if (i > 0) { builder.Append(","); } AppendTypeFullName(builder, parameterDefinition.ParameterType); } } builder.Append(')'); return builder.ToString(); static void AppendTypeFullName(StringBuilder builder, TypeReference type) { if (!string.IsNullOrEmpty(type.Namespace)) { builder.Append(type.Namespace); builder.Append("::"); } var stack = new Stack(); var declaringType = type; while (null != declaringType) { stack.Push(declaringType); declaringType = declaringType.DeclaringType; } while (stack.TryPop(out var nestedType)) AppendNestedType(builder, nestedType); } static void AppendNestedType(StringBuilder builder, TypeReference type) { builder.Append(type.Name); if (type is GenericInstanceType gInst) { builder.Append('<'); var args = gInst.GenericArguments; for (int i = 0; i < args.Count; i++) { if (i != 0) { builder.Append(','); } AppendTypeFullName(builder, args[i]); } builder.Append('>'); } } } internal static MethodReference FindMethod(this TypeDefinition td, string methodSignature) { var method = td.Methods.FirstOrDefault(m => m.GetSignature().Equals(methodSignature)); if (method is null) return null; return td.Module.ImportReference(method); } internal static MethodReference FindMethodByName(this TypeDefinition td, string methodName) => td.Module.ImportReference(td.Methods.FirstOrDefault(m => m.Name.Equals(methodName))); internal static TypeDefinition FindType(this ModuleDefinition md, Type type) { DebugHelper.IsNotNull(type); HashSet knownAssemblyNames = new(); List modules = new(); GetModules(md); foreach (var m in modules) { DebugHelper.IsNotNull(m); var tp = m.GetType(type.Namespace, type.Name); if (null != tp) { return tp; } } return null; void GetModules(ModuleDefinition md) { if (knownAssemblyNames.Contains(md.FileName)) return; var refModules = md.AssemblyReferences .Select(an => { try { return md.AssemblyResolver.Resolve(an).MainModule; } catch { return null; } }) .Where(r => r != null) .ToArray(); AddModule(md); foreach (var m in refModules) { GetModules(m); } } void AddModule(ModuleDefinition md) { var fileName = md.FileName; if (!knownAssemblyNames.Contains(fileName)) modules.Add(md); knownAssemblyNames.Add(fileName); } // return new TypeReference(type.Namespace,type.Name,md,md.TypeSystem.CoreLibrary); } internal static TypeReference FindType(this ModuleDefinition md) { return FindType(md, typeof(T)); // return new TypeReference(typeof(T).Namespace,typeof(T).Name,md,md.TypeSystem.CoreLibrary); } internal static TypeDefinition CreateDelegateType(this ModuleDefinition assembly, string name, TypeDefinition declaringType, TypeReference returnType, IEnumerable parameters) { var voidType = assembly.TypeSystem.Void; var objectType = assembly.TypeSystem.Object; var nativeIntType = assembly.TypeSystem.IntPtr; var asyncResultType = assembly.FindType(); var asyncCallbackType = assembly.FindType(); var multicastDelegateType = assembly.FindType(); var DelegateTypeAttributes = TypeAttributes.NestedPublic | TypeAttributes.Sealed; var dt = new TypeDefinition("", name, DelegateTypeAttributes, multicastDelegateType); dt.DeclaringType = declaringType; // add constructor var ConstructorAttributes = MethodAttributes.Public | MethodAttributes.HideBySig | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName; var constructor = new MethodDefinition(".ctor", ConstructorAttributes, voidType); constructor.Parameters.Add(new ParameterDefinition("objectInstance", ParameterAttributes.None, objectType)); constructor.Parameters.Add(new ParameterDefinition("functionPtr", ParameterAttributes.None, nativeIntType)); constructor.ImplAttributes = MethodImplAttributes.Runtime; dt.Methods.Add(constructor); // add BeginInvoke var DelegateMethodAttributes = MethodAttributes.Public | MethodAttributes.HideBySig | MethodAttributes.Virtual | MethodAttributes.VtableLayoutMask; var beginInvoke = new MethodDefinition("BeginInvoke", DelegateMethodAttributes, asyncResultType); foreach (var p in parameters) { beginInvoke.Parameters.Add(new ParameterDefinition(p)); } beginInvoke.Parameters.Add(new ParameterDefinition("callback", ParameterAttributes.None, asyncCallbackType)); beginInvoke.Parameters.Add(new ParameterDefinition("object", ParameterAttributes.None, objectType)); beginInvoke.ImplAttributes = MethodImplAttributes.Runtime; dt.Methods.Add(beginInvoke); // add EndInvoke var endInvoke = new MethodDefinition("EndInvoke", DelegateMethodAttributes, returnType); endInvoke.Parameters.Add(new ParameterDefinition("result", ParameterAttributes.None, asyncResultType)); endInvoke.ImplAttributes = MethodImplAttributes.Runtime; dt.Methods.Add(endInvoke); // add Invoke var invoke = new MethodDefinition("Invoke", DelegateMethodAttributes, returnType); foreach (var p in parameters) { // if(!p.IsValueType){ // invoke.Parameters.Add(new ParameterDefinition(p.Name,ParameterAttributes.In,objectType)); // }else{ invoke.Parameters.Add(new ParameterDefinition(p)); // } } invoke.ImplAttributes = MethodImplAttributes.Runtime; dt.Methods.Add(invoke); return dt; } static Type voidType = typeof(void); static Type objType = typeof(object); static OpCode[] s_ldargs = new[] { OpCodes.Ldarg_0, OpCodes.Ldarg_1, OpCodes.Ldarg_2, OpCodes.Ldarg_3 }; } }