SickGear/lib/js2py/utils/injector.py

240 lines
8.9 KiB
Python
Raw Normal View History

__all__ = ['fix_js_args']
import types
from collections import namedtuple
import opcode
import six
import sys
import dis
# Python 3 compatibility: the helpers below are written against the Python 2
# names.  ``xrange`` is aliased to ``range`` and ``chr`` becomes a
# pass-through, because bytecode is assembled with ``bytes(list_of_ints)`` on
# Python 3 rather than by joining 1-character strings.
if six.PY3:
    xrange = range

    def chr(x):
        return x

# Opcode constants used for comparison and replacement
LOAD_FAST, LOAD_GLOBAL, STORE_FAST = (
    opcode.opmap[opname]
    for opname in ('LOAD_FAST', 'LOAD_GLOBAL', 'STORE_FAST'))
def fix_js_args(func):
    '''Return func in a form that takes ``this`` and ``arguments`` as its last 2 args.

    If the last two positional parameters of func are already named
    ('this', 'arguments') or ('arguments', 'var'), func is returned untouched.
    Otherwise a new function is built from func's code object with ``this``
    and ``arguments`` appended to its parameters; the globals, name and
    closure are taken from func.'''
    code = six.get_function_code(func)
    argc = code.co_argcount
    trailing = code.co_varnames[argc - 2:argc]
    if trailing in (('this', 'arguments'), ('arguments', 'var')):
        return func
    patched = append_arguments(code, ('this', 'arguments'))
    return types.FunctionType(
        patched,
        six.get_function_globals(func),
        func.__name__,
        closure=six.get_function_closure(func))
def append_arguments(code_obj, new_locals):
    """Return a copy of ``code_obj`` that takes ``new_locals`` as extra arguments.

    The names in ``new_locals`` are spliced into ``co_varnames`` directly
    after the existing positional arguments.  Every ``LOAD_GLOBAL`` of one of
    those names is rewritten to ``LOAD_FAST``, and all other name/local
    operands are re-indexed to match the new ``co_names``/``co_varnames``.

    Args:
        code_obj: the code object to rewrite.
        new_locals: tuple of names to append as positional arguments.

    Returns:
        A new code object; ``code_obj`` itself is not modified.

    Raises:
        ValueError: if a ``LOAD_GLOBAL`` operand has no translation.
    """
    co_varnames = code_obj.co_varnames  # Old locals
    co_argcount = code_obj.co_argcount  # Argument count
    # Old globals, extended so that every new local has an index in co_names
    # (needed to build names_to_varnames below).
    co_names = code_obj.co_names
    co_names += tuple(e for e in new_locals if e not in co_names)

    # Make one pass over the bytecode to identify names that should be
    # left in co_names: anything referenced by a name-based opcode other
    # than LOAD_GLOBAL (attribute access, STORE_GLOBAL, imports, ...).
    # The instruction must be read by field name: on dis.Instruction the
    # first two positional fields are opname/opcode, not opcode/arg.
    # NOTE(review): assumes the operand of name opcodes is a plain co_names
    # index, which holds for the interpreter versions that pass check().
    not_removed = set(opcode.hasname) - {LOAD_GLOBAL}
    saved_names = set()
    for inst in instructions(code_obj):
        if inst.opcode in not_removed:
            saved_names.add(co_names[inst.arg])

    # Build co_names for the new code object.  Names that become locals are
    # dropped unless something other than LOAD_GLOBAL still refers to them.
    dropped = set(new_locals) - saved_names
    names = tuple(name for name in co_names if name not in dropped)

    # Maps indices of entries in the old co_names to the new co_names.
    name_translations = {co_names.index(name): i
                         for i, name in enumerate(names)}

    # Build co_varnames for the new code object: the old arguments, then
    # new_locals, then whatever followed the arguments originally.
    new_locals_len = len(new_locals)
    varnames = (co_varnames[:co_argcount] + new_locals +
                co_varnames[co_argcount:])

    # Maps indices in the old co_varnames to the new co_varnames: arguments
    # keep their slot, everything after them shifts by len(new_locals).
    varname_translations = {i: i for i in xrange(co_argcount)}
    varname_translations.update(
        (i, i + new_locals_len)
        for i in xrange(co_argcount, len(co_varnames)))

    # Maps co_names indices of the new locals to their co_varnames slots.
    names_to_varnames = {co_names.index(name): varnames.index(name)
                         for name in new_locals}

    # Now we modify the actual bytecode
    modified = []
    for inst in instructions(code_obj):
        op, arg = inst.opcode, inst.arg
        if op == LOAD_GLOBAL:
            # Either one of the globals being turned into a local, or an
            # ordinary global whose index may have moved.
            if arg in names_to_varnames:
                op = LOAD_FAST
                arg = names_to_varnames[arg]
            elif arg in name_translations:
                arg = name_translations[arg]
            else:
                raise ValueError("a name was lost in translation")
        # If it accesses co_varnames or co_names then update its argument.
        elif op in opcode.haslocal:
            arg = varname_translations[arg]
        elif op in opcode.hasname:
            arg = name_translations[arg]
        modified.extend(write_instruction(op, arg))

    new_argcount = co_argcount + new_locals_len
    new_nlocals = code_obj.co_nlocals + new_locals_len

    if six.PY2:
        return types.CodeType(
            new_argcount,
            new_nlocals,
            code_obj.co_stacksize,
            code_obj.co_flags,
            ''.join(modified),
            code_obj.co_consts,
            names,
            varnames,
            code_obj.co_filename,
            code_obj.co_name,
            code_obj.co_firstlineno,
            code_obj.co_lnotab,
            code_obj.co_freevars,
            code_obj.co_cellvars)

    code = bytes(modified)
    if hasattr(code_obj, 'replace'):
        # Python 3.8+: the CodeType constructor gained extra positional
        # parameters, so the fixed-position call below raises TypeError.
        # replace() copies every field that is not overridden.
        return code_obj.replace(
            co_argcount=new_argcount,
            co_nlocals=new_nlocals,
            co_code=code,
            co_names=names,
            co_varnames=varnames)

    # Python 3.0 - 3.7 constructor.  Keyword-only parameters sit right after
    # the positional ones in co_varnames, i.e. right after the splice point,
    # so their count is carried over rather than reset to 0.
    return types.CodeType(
        new_argcount,
        code_obj.co_kwonlyargcount,
        new_nlocals,
        code_obj.co_stacksize,
        code_obj.co_flags,
        code,
        code_obj.co_consts,
        names,
        varnames,
        code_obj.co_filename,
        code_obj.co_name,
        code_obj.co_firstlineno,
        code_obj.co_lnotab,
        code_obj.co_freevars,
        code_obj.co_cellvars)
def instructions(code_obj):
    """Yield the instructions of ``code_obj`` in bytecode order.

    On Python 3.4+ the items are whatever ``dis.Bytecode`` produces.  On
    older interpreters ``co_code`` is decoded by hand into
    ``(opcode, arg)`` named tuples; there ``EXTENDED_ARG`` is not yielded
    on its own but folded into the operand of the following instruction,
    and instructions without an operand get ``arg=None``.
    """
    if sys.version_info >= (3, 4):
        # dis already knows how to decode this interpreter's bytecode.
        for decoded in dis.Bytecode(code_obj):
            yield decoded
        return

    # Legacy layout: 1 opcode byte, optionally followed by a 2-byte
    # little-endian operand.
    raw = code_obj.co_code
    NewInstruction = namedtuple('Instruction', ('opcode', 'arg'))
    if six.PY2:
        # co_code is a str on Python 2; work on integer byte values.
        raw = [ord(ch) for ch in raw]
    end = len(raw)
    pos = 0
    pending_ext = 0  # high bits contributed by a preceding EXTENDED_ARG
    while pos < end:
        op = raw[pos]
        pos += 1
        if op >= opcode.HAVE_ARGUMENT:
            oparg = raw[pos] + (raw[pos + 1] << 8) + pending_ext
            pending_ext = 0
            pos += 2
            if op == opcode.EXTENDED_ARG:
                pending_ext = oparg << 16
            else:
                yield NewInstruction(op, oparg)
        else:
            yield NewInstruction(op, None)
def write_instruction(op, arg):
    """Encode one instruction as a list of byte values.

    Args:
        op: integer opcode.
        arg: integer operand, or None for an instruction without one.

    Returns:
        A list of items to append to the bytecode being assembled.  On
        Python 3 the items are ints (suitable for ``bytes(...)``); on
        Python 2 they are 1-character strings (suitable for ``''.join``).

    Raises:
        ValueError: on interpreters older than 3.6, if ``arg`` does not fit
            in 32 bits.
    """
    if sys.version_info >= (3, 6):
        # Wordcode: every instruction is exactly (opcode, low operand byte).
        # The instruction stream already carries the EXTENDED_ARG prefixes
        # for the high bytes, so only the low byte is written here.  The
        # module-level ``chr`` shim is the identity on Python 3, so plain
        # ints are emitted directly.
        return [op, 0 if arg is None else arg & 255]

    # Pre-3.6 bytecode: 1 byte, or 3 bytes with a 16-bit little-endian
    # operand, optionally preceded by EXTENDED_ARG carrying bits 16-31.
    if arg is None:
        return [chr(op)]
    if arg < 1 << 16:
        return [chr(op), chr(arg & 255), chr((arg >> 8) & 255)]
    if arg < 1 << 32:
        return [chr(opcode.EXTENDED_ARG),
                chr((arg >> 16) & 255),
                chr((arg >> 24) & 255),
                chr(op),
                chr(arg & 255),
                chr((arg >> 8) & 255)]
    raise ValueError("Invalid oparg: {0} is too large".format(arg))
def check(code_obj):
    """Verify that decoding and re-encoding ``code_obj`` is lossless.

    The bytecode is decoded with ``instructions`` and re-encoded with
    ``write_instruction``; the result must equal ``code_obj.co_code``.
    On a mismatch both byte strings and the instructions around the first
    differing offset are printed before the error is raised.

    Raises:
        RuntimeError: if the re-encoded bytecode differs from the original.
    """
    old_bytecode = code_obj.co_code

    pos_to_inst = {}  # byte offset in the re-encoded stream -> instruction
    bytelist = []
    for inst in instructions(code_obj):
        pos_to_inst[len(bytelist)] = inst
        bytelist.extend(write_instruction(inst.opcode, inst.arg))

    if six.PY2:
        new_bytecode = ''.join(bytelist)
    else:
        new_bytecode = bytes(bytelist)
    if new_bytecode == old_bytecode:
        return

    print(new_bytecode)
    print(old_bytecode)
    for i in range(min(len(new_bytecode), len(old_bytecode))):
        if old_bytecode[i] != new_bytecode[i]:
            # The mismatch may fall on an operand byte, so look up the
            # instruction that contains offset i (and the one before it)
            # instead of requiring i to be an instruction start.
            starts = sorted(pos for pos in pos_to_inst if pos <= i)
            print(pos_to_inst[starts[-1]])
            if len(starts) > 1:
                print(pos_to_inst[starts[-2]])
            lo = max(i - 4, 0)  # a negative start would slice from the end
            print(list(old_bytecode)[lo:i + 8])
            print(bytelist[lo:i + 8])
            break
    # Any difference, including a pure length difference, means the
    # encoder/decoder pair does not match this interpreter.
    raise RuntimeError('Your python version made changes to the bytecode')
# Import-time self-test: round-trip this module's own ``check`` function
# through the decoder/encoder so an incompatible interpreter fails early.
check(six.get_function_code(check))

if __name__ == '__main__':
    # Smoke test: turn the globals x, y and z of ``func`` into trailing
    # arguments and make sure the rewritten function sees the passed values.
    x = 'Wrong'
    untouched_global = 3000  # must still be resolved as a global afterwards

    def func(a):
        print(x, y, z, a)
        print(untouched_global)
        d = (x,)
        for e in (e for e in x):
            print(e)
        return x, y, z

    rewritten = append_arguments(six.get_function_code(func), ('x', 'y', 'z'))
    func2 = types.FunctionType(
        rewritten,
        six.get_function_globals(func),
        func.__name__,
        closure=six.get_function_closure(func))
    args = (2, 2, 3, 4), 3, 4
    assert func2(1, *args) == args