using System;
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Cecil.Rocks;

namespace OverrideVariance
{
    /// <summary>
    /// Post-compilation rewriter that emulates covariant return types.
    /// For every virtual method whose return value carries <c>OverrideVarianceAttribute</c>, a new method
    /// with the same name, attributes and parameters (but the attribute-specified return type) is added and
    /// forwards to the base virtual method. The original method is turned into a private explicit override.
    /// </summary>
    static class Program
    {
        private const string VARIANCE_ATTRIBUTE = "OverrideVarianceAttribute";

        /// <summary>
        /// Reads the module named by the first command-line argument, rewrites it and writes it back
        /// to the same path.
        /// </summary>
        /// <param name="args">args[0] is the path of the module to transform in place.</param>
        /// <exception cref="InvalidOperationException">
        /// A method marked with the variance attribute does not override any base method.
        /// </exception>
        static void Main(string[] args)
        {
            if (args.Length < 1)
            {
                Console.Error.WriteLine("usage: OverrideVariance <assembly-path>");
                Environment.ExitCode = 1;
                return;
            }

            string fileName = args[0];
            ModuleDefinition module = ModuleDefinition.ReadModule(fileName);

            // find all the methods with the VARIANCE_ATTRIBUTE on the return type
            IEnumerable<Tuple<MethodDefinition, TypeReference>> attrMethods = FindVarianceMethods(module);

            // and apply the transformation
            foreach (var attrMethod in attrMethods)
            {
                MethodDefinition method = attrMethod.Item1;

                // Resolve the base method exactly once, before anything is rewritten: once the method
                // is marked NewSlot, Cecil's GetBaseMethod() stops searching and returns the method itself.
                MethodDefinition overriddenMethod = method.GetBaseMethod();
                if (overriddenMethod == method)
                {
                    // GetBaseMethod() hands back its argument when no base method exists. Carrying on would
                    // emit a method that calls itself and an override of itself, so fail before writing.
                    throw new InvalidOperationException(
                        "Method " + method.FullName + " is marked with " + VARIANCE_ATTRIBUTE +
                        " but does not override a base method.");
                }

                AddVarianceMethod(method, overriddenMethod, attrMethod.Item2);
                RewriteExistingMethod(method, overriddenMethod);
            }

            // write out modified assembly to disk
            module.Write(fileName);
        }

        /// <summary>
        /// Collects every virtual method in the module whose return value carries the variance attribute.
        /// The attribute type's FullName is compared, so only an attribute declared in the global
        /// namespace matches.
        /// </summary>
        /// <param name="module">The module to scan (all types, including nested ones).</param>
        /// <returns>
        /// Pairs of (method, type passed as the attribute's first constructor argument). The result is
        /// materialised with ToList() because the caller adds methods to the declaring types while
        /// iterating it.
        /// </returns>
        private static IEnumerable<Tuple<MethodDefinition, TypeReference>> FindVarianceMethods(ModuleDefinition module)
        {
            return (from type in module.GetAllTypes()
                    from method in type.Methods
                    where method.IsVirtual
                    let attr = method.MethodReturnType.CustomAttributes
                        .FirstOrDefault(a => a.AttributeType.FullName == VARIANCE_ATTRIBUTE)
                    where attr != null
                    select Tuple.Create(method, (TypeReference)attr.ConstructorArguments[0].Value)).ToList();
        }

        /// <summary>
        /// Adds a new method to <paramref name="origMethod"/>'s declaring type.
        /// The new method has the same name, attributes and parameters as <paramref name="origMethod"/>
        /// but returns <paramref name="returnType"/>. Its body forwards 'this' and all arguments to
        /// <paramref name="methodToCall"/> via callvirt, then converts the result when the return types differ.
        /// </summary>
        /// <param name="origMethod">The attributed method whose signature is copied.</param>
        /// <param name="methodToCall">The base virtual method the new method dispatches through.</param>
        /// <param name="returnType">The return type requested by the attribute.</param>
        private static void AddVarianceMethod(MethodDefinition origMethod, MethodReference methodToCall, TypeReference returnType)
        {
            MethodDefinition newMethod = new MethodDefinition(origMethod.Name, origMethod.Attributes, returnType);

            // Copy the parameters as NEW definitions. Adding a ParameterDefinition to a method records that
            // method as its owner (and its position), so the original's instances must not be shared.
            foreach (ParameterDefinition parameter in origMethod.Parameters)
            {
                newMethod.Parameters.Add(new ParameterDefinition(parameter.Name, parameter.Attributes, parameter.ParameterType));
            }

            ILProcessor il = newMethod.Body.GetILProcessor();

            // call the base virtual method which dispatches to the correct run-time implementation
            // NOTE(review): if methodToCall is declared in another assembly it presumably must be imported
            // into this module before IL can reference it; that is not handled here. Confirm.
            EmitMethodCall(il, methodToCall);

            if (origMethod.ReturnType.FullName != returnType.FullName)
            {
                // unbox.any acts like castclass for a reference type and unboxes for a value type
                il.Emit(OpCodes.Unbox_Any, returnType);
            }
            il.Emit(OpCodes.Ret);

            // Turn the generic ldarg instructions into their short encodings (ldarg.1..3 / ldarg.s).
            newMethod.Body.OptimizeMacros();

            // The evaluation stack peaks just before the call: 'this' plus every argument.
            newMethod.Body.MaxStackSize = newMethod.Parameters.Count + 1;

            // add this method to the declaring class
            origMethod.DeclaringType.Methods.Add(newMethod);
        }

        /// <summary>
        /// Emits a callvirt to <paramref name="method"/>. The call forwards 'this' and then every parameter
        /// of the method whose body <paramref name="il"/> writes into, in declaration order. This assumes
        /// that method's parameter list matches the callee's.
        /// </summary>
        /// <param name="il">Processor for the body being generated.</param>
        /// <param name="method">The virtual method to invoke.</param>
        private static void EmitMethodCall(ILProcessor il, MethodReference method)
        {
            il.Emit(OpCodes.Ldarg_0);

            // Referencing each argument by its ParameterDefinition lets Cecil encode the correct
            // index for any number of parameters.
            foreach (ParameterDefinition parameter in il.Body.Method.Parameters)
            {
                il.Emit(OpCodes.Ldarg, parameter);
            }

            il.Emit(OpCodes.Callvirt, method);
        }

        /// <summary>
        /// Turns <paramref name="method"/> into a private, final, new-slot method that explicitly
        /// overrides <paramref name="overridden"/>. It is renamed "DeclaringTypeFullName.Name" so it no
        /// longer clashes with the forwarding method added under the original name.
        /// </summary>
        /// <param name="method">The attributed method to rewrite.</param>
        /// <param name="overridden">
        /// The base method it overrides, resolved by the caller before any attributes were changed.
        /// </param>
        private static void RewriteExistingMethod(MethodDefinition method, MethodReference overridden)
        {
            // change to private newslot final
            method.Attributes = (method.Attributes & ~MethodAttributes.MemberAccessMask)
                | MethodAttributes.Private | MethodAttributes.NewSlot | MethodAttributes.Final;

            // set the explicit override
            // NOTE(review): as in AddVarianceMethod, a base method from another assembly presumably
            // needs importing into this module first. Confirm.
            method.Overrides.Add(overridden);

            // add the overridden method type to the name to prevent name clashes
            method.Name = overridden.DeclaringType.FullName + "." + method.Name;
        }
    }
}