"""Bytecode injector: rewrite a function's code object to take extra arguments.

``fix_js_args`` is the public entry point.  It relies on ``append_arguments``,
which splices new positional parameters into a code object and redirects every
``LOAD_GLOBAL`` of those names to a ``LOAD_FAST`` of the new parameter.

At import time the module round-trips the bytecode of ``check`` itself through
``instructions`` / ``write_instruction`` and raises ``RuntimeError`` when the
result differs from the original.
"""
__all__ = ['fix_js_args']

import dis
import opcode
import sys
import types
from collections import namedtuple

import six

if six.PY3:
    xrange = range
    # On Python 3 the bytecode is assembled from ints into ``bytes`` (see
    # ``append_arguments`` / ``check``), so ``chr`` must be the identity.
    chr = lambda x: x

# Opcode constants used for comparison and replacement
LOAD_FAST = opcode.opmap['LOAD_FAST']
LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
STORE_FAST = opcode.opmap['STORE_FAST']

# Instruction record produced by the manual parser in ``instructions``.
_ParsedInstruction = namedtuple('Instruction', ('opcode', 'arg'))


def fix_js_args(func):
    """Return ``func`` guaranteed to accept ``this`` and ``arguments`` last.

    Use this function when unsure whether ``func`` takes ``this`` and
    ``arguments`` as its last 2 args.  If its last two declared arguments are
    already ``('this', 'arguments')`` (or ``('arguments', 'var')``) the
    function is returned unchanged; otherwise a new function is built whose
    code object has the two arguments appended.
    """
    fcode = six.get_function_code(func)
    fargs = fcode.co_varnames[fcode.co_argcount - 2:fcode.co_argcount]
    if fargs == ('this', 'arguments') or fargs == ('arguments', 'var'):
        return func
    code = append_arguments(fcode, ('this', 'arguments'))
    return types.FunctionType(
        code,
        six.get_function_globals(func),
        func.__name__,
        closure=six.get_function_closure(func))


def append_arguments(code_obj, new_locals):
    """Return a copy of ``code_obj`` taking ``new_locals`` as extra arguments.

    The names in ``new_locals`` (a tuple of strings) are inserted into
    ``co_varnames`` right after the existing arguments, and every
    ``LOAD_GLOBAL`` of one of those names is rewritten to a ``LOAD_FAST`` of
    the new argument.  All other name/local operands are re-indexed to match
    the new ``co_names`` / ``co_varnames`` tuples.

    Raises:
        ValueError: if a ``LOAD_GLOBAL`` operand maps to no known name.
    """
    co_varnames = code_obj.co_varnames  # Old locals
    co_names = code_obj.co_names  # Old globals
    co_names += tuple(e for e in new_locals if e not in co_names)
    co_argcount = code_obj.co_argcount  # Argument count

    # Make one pass over the bytecode to identify names that should be
    # left in code_obj.co_names.
    not_removed = set(opcode.hasname) - set([LOAD_GLOBAL])
    saved_names = set()
    for inst in instructions(code_obj):
        # Access by attribute, not by position: ``dis.Instruction`` starts
        # with ``opname``, so ``inst[0]`` is not the numeric opcode there.
        if inst.opcode in not_removed:
            saved_names.add(co_names[inst.arg])

    # Build co_names for the new code object. This should consist of
    # globals that were only accessed via LOAD_GLOBAL
    removed_names = set(new_locals) - saved_names
    names = tuple(name for name in co_names if name not in removed_names)

    # Build a dictionary that maps the indices of the entries in co_names
    # to their entry in the new co_names
    name_translations = dict(
        (co_names.index(name), i) for i, name in enumerate(names))

    # Build co_varnames for the new code object. This should consist of
    # the entirety of co_varnames with new_locals spliced in after the
    # arguments
    new_locals_len = len(new_locals)
    varnames = (
        co_varnames[:co_argcount] + new_locals + co_varnames[co_argcount:])

    # Build the dictionary that maps indices of entries in the old co_varnames
    # to their indices in the new co_varnames
    range1 = xrange(co_argcount)
    range2 = xrange(co_argcount, len(co_varnames))
    varname_translations = dict((i, i) for i in range1)
    varname_translations.update((i, i + new_locals_len) for i in range2)

    # Build the dictionary that maps indices of deleted entries of co_names
    # to their indices in the new co_varnames
    names_to_varnames = dict(
        (co_names.index(name), varnames.index(name)) for name in new_locals)

    # Now we modify the actual bytecode
    modified = []
    for inst in instructions(code_obj):
        op, arg = inst.opcode, inst.arg
        # If the instruction is a LOAD_GLOBAL, we have to check to see if
        # it's one of the globals that we are replacing. Either way,
        # update its arg using the appropriate dict.
        if inst.opcode == LOAD_GLOBAL:
            if inst.arg in names_to_varnames:
                op = LOAD_FAST
                arg = names_to_varnames[inst.arg]
            elif inst.arg in name_translations:
                arg = name_translations[inst.arg]
            else:
                raise ValueError("a name was lost in translation")
        # If it accesses co_varnames or co_names then update its argument.
        elif inst.opcode in opcode.haslocal:
            arg = varname_translations[inst.arg]
        elif inst.opcode in opcode.hasname:
            arg = name_translations[inst.arg]
        modified.extend(write_instruction(op, arg))

    new_argcount = co_argcount + new_locals_len
    new_nlocals = code_obj.co_nlocals + new_locals_len

    # Done modifying codestring - make the code object
    if six.PY2:
        code = ''.join(modified)
        return types.CodeType(
            new_argcount, new_nlocals, code_obj.co_stacksize,
            code_obj.co_flags, code, code_obj.co_consts, names, varnames,
            code_obj.co_filename, code_obj.co_name, code_obj.co_firstlineno,
            code_obj.co_lnotab, code_obj.co_freevars, code_obj.co_cellvars)

    code = bytes(modified)
    if hasattr(code_obj, 'replace'):
        # Python 3.8+: the positional CodeType constructor gained extra
        # parameters, so copy the code object and override only what changed.
        return code_obj.replace(
            co_argcount=new_argcount,
            co_nlocals=new_nlocals,
            co_code=code,
            co_names=names,
            co_varnames=varnames)
    # Older Python 3: the second positional argument is kwonlyargcount.
    return types.CodeType(
        new_argcount, 0, new_nlocals, code_obj.co_stacksize,
        code_obj.co_flags, code, code_obj.co_consts, names, varnames,
        code_obj.co_filename, code_obj.co_name, code_obj.co_firstlineno,
        code_obj.co_lnotab, code_obj.co_freevars, code_obj.co_cellvars)


def instructions(code_obj):
    """Yield the instructions of ``code_obj``.

    Every yielded item exposes ``.opcode`` and ``.arg`` (``arg`` is ``None``
    for opcodes below ``opcode.HAVE_ARGUMENT`` on the manual path).  On
    Python 3.4+ the items come from ``dis.Bytecode``; on older versions the
    raw bytecode is parsed by hand and ``EXTENDED_ARG`` prefixes are folded
    into the argument of the instruction that follows them.
    """
    # easy for python 3.4+
    if sys.version_info >= (3, 4):
        for inst in dis.Bytecode(code_obj):
            yield inst
        return

    # otherwise we have to manually parse
    code = code_obj.co_code
    if six.PY2:
        code = map(ord, code)
    i, length = 0, len(code)
    extended_arg = 0
    while i < length:
        op = code[i]
        i += 1
        if op < opcode.HAVE_ARGUMENT:
            yield _ParsedInstruction(op, None)
            continue
        # Arguments are two bytes, little endian, plus any pending prefix.
        oparg = code[i] + (code[i + 1] << 8) + extended_arg
        extended_arg = 0
        i += 2
        if op == opcode.EXTENDED_ARG:
            extended_arg = oparg << 16
            continue
        yield _ParsedInstruction(op, oparg)


def write_instruction(op, arg):
    """Encode one instruction as a list of code units.

    The units are one-character strings on Python 2 and ints on Python 3
    (where the module-level ``chr`` is the identity), ready to be joined
    into ``co_code``.

    Raises:
        ValueError: before Python 3.6, if ``arg`` does not fit in 32 bits.
    """
    if sys.version_info < (3, 6):
        if arg is None:
            return [chr(op)]
        elif arg <= 0xFFFF:
            # Fits in the two argument bytes (little endian).
            return [chr(op), chr(arg & 255), chr((arg >> 8) & 255)]
        elif arg <= 0xFFFFFFFF:
            # The upper 16 bits go into a preceding EXTENDED_ARG.
            return [
                chr(opcode.EXTENDED_ARG),
                chr((arg >> 16) & 255),
                chr((arg >> 24) & 255),
                chr(op),
                chr(arg & 255),
                chr((arg >> 8) & 255),
            ]
        else:
            raise ValueError(
                "Invalid oparg: {0} is too large".format(arg))
    else:
        # python 3.6+ uses wordcode instead of bytecode and the instruction
        # stream already supplies all the EXTENDED_ARG ops, so only the low
        # byte of the argument is written here.
        # NOTE: EXTENDED_ARG prefixes are not synthesized automatically.
        if arg is None:
            return [chr(op), 0]
        return [chr(op), arg & 255]


def check(code_obj):
    """Verify that ``code_obj`` survives a decode/encode round trip.

    The bytecode is decoded with ``instructions`` and re-encoded with
    ``write_instruction``.  If the result differs from ``co_code``, both
    byte strings and the instructions around each mismatch are printed.

    Raises:
        RuntimeError: if the re-encoded bytecode differs from the original.
    """
    old_bytecode = code_obj.co_code
    insts = list(instructions(code_obj))

    pos_to_inst = {}
    bytelist = []
    for inst in insts:
        pos_to_inst[len(bytelist)] = inst
        bytelist.extend(write_instruction(inst.opcode, inst.arg))

    if six.PY2:
        new_bytecode = ''.join(bytelist)
    else:
        new_bytecode = bytes(bytelist)

    if new_bytecode != old_bytecode:
        print(new_bytecode)
        print(old_bytecode)
        reported = set()
        for i in range(min(len(new_bytecode), len(old_bytecode))):
            if old_bytecode[i] == new_bytecode[i]:
                continue
            # Walk back to the start of the instruction containing byte i.
            # Offset 0 is always a key whenever bytelist is non-empty.
            start = i
            while start not in pos_to_inst:
                start -= 1
            if start in reported:
                continue
            reported.add(start)
            window_start = max(0, start - 4)
            print(pos_to_inst[start])
            # May be None when no instruction starts exactly 2 units earlier.
            print(pos_to_inst.get(start - 2))
            print(list(map(chr, old_bytecode))[window_start:start + 8])
            print(bytelist[window_start:start + 8])
        raise RuntimeError(
            'Your python version made changes to the bytecode')


# Self-test at import time: fail early if this interpreter's bytecode cannot
# be round-tripped by the encoder/decoder above.
check(six.get_function_code(check))

if __name__ == '__main__':
    x = 'Wrong'
    number = 3000

    def func(a):
        print(x, y, z, a)
        print(number)
        d = (x,)
        for e in (e for e in x):
            print(e)
        return x, y, z

    func2 = types.FunctionType(
        append_arguments(six.get_function_code(func), ('x', 'y', 'z')),
        six.get_function_globals(func),
        func.__name__,
        closure=six.get_function_closure(func))
    args = (2, 2, 3, 4), 3, 4
    assert func2(1, *args) == args