Coverage for C:\Users\tuhe\Documents\unitgrade\unitgrade2\unitgrade2.py : 15%
 
         
         
    Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2git add . && git commit -m "Options" && git push && pip install git+ssh://git@gitlab.compute.dtu.dk/tuhe/unitgrade.git --upgrade
4"""
5from . import cache_read
6import unittest
7import numpy as np
8import os
9import sys
10from io import StringIO
11import collections
12import inspect
13import re
14import threading
15import tqdm
16import time
17import pickle
18import itertools
def myround(x):
    """Round ``x`` with :func:`numpy.round` and return the result."""
    return np.round(x)  # required.


def msum(x):
    """Return the built-in :func:`sum` of the iterable ``x``."""
    return sum(x)


def mfloor(x):
    """Return the floor of ``x`` as computed by :func:`numpy.floor`."""
    return np.floor(x)
def setup_dir_by_class(C, base_dir):
    """Pair ``base_dir`` with the class name of the instance ``C``.

    No directory is created or modified; ``base_dir`` is handed back untouched.

    :param C: Any object; only the name of its class is read.
    :param base_dir: Directory path returned as the first tuple element.
    :return: ``(base_dir, class_name)``.
    """
    class_name = C.__class__.__name__
    return (base_dir, class_name)
class Hidden:
    """Marker base class whose :meth:`hide` always answers ``True``."""

    def hide(self):
        """Return ``True`` unconditionally."""
        return True
class Logger(object):
    """Writer that duplicates every message to the terminal and to a buffer.

    The terminal is whatever ``sys.stdout`` is at construction time.
    """

    def __init__(self, buffer):
        """Remember ``buffer`` and the current ``sys.stdout``."""
        self.log = buffer
        self.terminal = sys.stdout

    def write(self, message):
        """Write ``message`` first to the terminal, then to the buffer."""
        for sink in (self.terminal, self.log):
            sink.write(message)

    def flush(self):
        """Do nothing; present so the object can stand in for a stream."""
        # this flush method is needed for python 3 compatibility.
        return None
class Capturing(list):
    """Context manager that collects everything printed to stdout as a list of lines.

    While the context is active ``sys.stdout`` is replaced by an in-memory
    buffer (or by a ``Logger`` that also echoes to the real stdout when
    ``unmute`` is true) and ``sys.stderr`` is redirected to a throw-away
    buffer. On exit both streams are restored and the captured stdout text is
    appended to this list, one entry per line.
    """

    def __init__(self, *args, unmute=False, **kwargs):
        """Create the (initially empty unless ``args`` say otherwise) line list.

        :param unmute: When true, captured output is also echoed to stdout.
        """
        self.unmute = unmute
        super().__init__(*args, **kwargs)

    def __enter__(self, capture_errors=True):  # don't put arguments here.
        """Install the replacement streams and return ``self``."""
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = Logger(self._stringio) if self.unmute else self._stringio

        if capture_errors:
            # The original assigned to the misspelled ``sys.sterr``, which left
            # the real stderr untouched and planted a stray attribute on ``sys``.
            self._stderr = sys.stderr
            sys.stderr = StringIO()  # memory hole it
        self.capture_errors = capture_errors
        return self

    def __exit__(self, *args):
        """Store the captured lines and put the original streams back."""
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout
        if self.capture_errors:
            sys.stderr = self._stderr
class QItem(unittest.TestCase):
    """A single gradable item: computes an answer and compares it to an expected payload.

    Sub-classes implement :meth:`compute_answer`. The comparison function is
    ``testfun``; when it is left as ``None`` the constructor picks
    ``assertL2Relative`` if ``tol > 0`` and ``assertEqual`` otherwise.

    :meth:`get_points` reads ``self._correct_answer_payload``; this class never
    assigns that attribute, so it must be set on the instance beforehand.
    """

    # ``test(computed, expected)`` below is a helper, not a test case. Without
    # this flag pytest collects it from any module that imports QItem and
    # reports a spurious failure. unittest itself ignores the attribute.
    __test__ = False

    title = None
    testfun = None
    tol = 0
    estimated_time = 0.42
    _precomputed_payload = None
    _precomputed_title = None  # Read by _safe_get_title; previously undefined.
    _computed_answer = None  # Internal helper to later get results.
    weight = 1  # the weight of the question.

    def __init__(self, question=None, *args, **kwargs):
        """Select the comparison function, remember ``question`` and set a default title.

        :param question: Object stored as ``self.question``.
        Remaining arguments are forwarded to :class:`unittest.TestCase`.
        """
        if self.testfun is None:
            self.testfun = self.assertL2Relative if self.tol > 0 else self.assertEqual

        self.name = self.__class__.__name__
        self.question = question

        super().__init__(*args, **kwargs)
        if self.title is None:
            self.title = self.name

    def _safe_get_title(self):
        """Return ``_precomputed_title`` when it is set, otherwise ``title``."""
        if self._precomputed_title is not None:
            return self._precomputed_title
        return self.title

    def assertNorm(self, computed, expected, tol=None):
        """Fail unless the Euclidean norm of ``computed - expected`` is at most ``tol``.

        Both inputs are flattened before subtracting, so arrays of any matching
        shape are handled. The norm is stored in ``self.error_computed``.

        :param tol: Tolerance; defaults to ``self.tol``.
        """
        if tol is None:
            tol = self.tol
        diff = np.abs(np.asarray(computed).ravel() - np.asarray(expected).ravel())
        nrm = np.sqrt(np.sum(diff ** 2))

        self.error_computed = nrm

        if nrm > tol:
            print(f"Not equal within tolerance {tol}; norm of difference was {nrm}")
            print(f"Element-wise differences {diff.tolist()}")
            self.assertEqual(computed, expected, msg=f"Not equal within tolerance {tol}")

    def assertL2(self, computed, expected, tol=None):
        """Fail unless the largest absolute element-wise difference is at most ``tol``.

        The largest difference is stored in ``self.error_computed``.

        :param tol: Tolerance; defaults to ``self.tol``.
        """
        if tol is None:
            tol = self.tol
        diff = np.abs((np.asarray(computed) - np.asarray(expected)))
        self.error_computed = np.max(diff)

        if np.max(diff) > tol:
            print(f"Not equal within tolerance {tol=}; deviation was {np.max(diff)=}")
            print(f"Element-wise differences {diff.tolist()}")
            self.assertEqual(computed, expected, msg=f"Not equal within tolerance {tol=}, {np.max(diff)=}")

    def assertL2Relative(self, computed, expected, tol=None):
        """Fail if any relative difference ``|c - e| / (1e-8 + |c + e|)`` exceeds ``tol``.

        The largest relative difference is stored in ``self.error_computed``.

        :param tol: Tolerance; defaults to ``self.tol``.
        """
        if tol is None:
            tol = self.tol
        diff = np.abs((np.asarray(computed) - np.asarray(expected)))
        diff = diff / (1e-8 + np.abs((np.asarray(computed) + np.asarray(expected))))
        self.error_computed = np.max(np.abs(diff))
        if np.sum(diff > tol) > 0:
            print(f"Not equal within tolerance {tol}")
            print(f"Element-wise differences {diff.tolist()}")
            self.assertEqual(computed, expected, msg=f"Not equal within tolerance {tol}")

    def precomputed_payload(self):
        """Return the stored ``_precomputed_payload``."""
        return self._precomputed_payload

    def precompute_payload(self):
        """Hook for sub-classes; the default does nothing."""
        # Pre-compute resources to include in tests (useful for getting around rng).
        pass

    def compute_answer(self, unmute=False):
        """Produce the value to be graded; must be overridden by sub-classes."""
        raise NotImplementedError("test code here")

    def test(self, computed, expected):
        """Compare ``computed`` to ``expected`` with the configured ``testfun``."""
        self.testfun(computed, expected)

    def get_points(self, verbose=False, show_expected=False, show_computed=False, unmute=False, passall=False, silent=False, **kwargs):
        """Run the item and return ``(obtained, possible)`` with ``possible == 1``.

        :param show_expected: Print the expected payload before comparing.
        :param show_computed: Print the computed answer before comparing.
        :param unmute: Forwarded to :meth:`compute_answer`.
        :param passall: Skip the comparison and ignore errors from
            :meth:`compute_answer`.
        :param silent: Suppress the diagnostic print-out on failure.
        :return: ``(1, 1)`` on success, ``(0, 1)`` on failure.
        """
        possible = 1
        computed = None

        def show_computed_(computed):
            print(">>> Your output:")
            print(computed)

        def show_expected_(expected):
            print(">>> Expected output (note: may have been processed; read text script):")
            print(expected)

        correct = self._correct_answer_payload
        try:
            if unmute:  # Required to not mix together print stuff.
                print("")
            computed = self.compute_answer(unmute=unmute)
        except Exception as e:
            if not passall:
                if not silent:
                    print("\n=================================================================================")
                    print(f"When trying to run test class '{self.name}' your code threw an error:", e)
                    show_expected_(correct)
                    import traceback
                    print(traceback.format_exc())
                    print("=================================================================================")
                return (0, possible)

        # Only the first computed answer is remembered.
        if self._computed_answer is None:
            self._computed_answer = computed

        if show_expected or show_computed:
            print("\n")
        if show_expected:
            show_expected_(correct)
        if show_computed:
            show_computed_(computed)
        try:
            if not passall:
                self.test(computed=computed, expected=correct)
        except Exception as e:
            if not silent:
                print("\n=================================================================================")
                print(f"Test output from test class '{self.name}' does not match expected result. Test error:")
                print(e)
                show_computed_(computed)
                show_expected_(correct)
            return (0, possible)
        return (1, possible)

    def score(self):
        """Return 1 if ``self.test()`` succeeds and 0 otherwise.

        NOTE(review): ``test`` requires ``computed`` and ``expected``, so the
        argument-less call below always raises and this method always returns
        0. Left unchanged because the intended inputs are not visible here;
        :meth:`get_points` is the working entry point.
        """
        try:
            self.test()
        except Exception:
            return 0
        return 1
class QPrintItem(QItem):
    """Item whose answer is derived from printed output as well as a return value."""

    def compute_answer_print(self):
        """
        Generate output which is to be tested. By default, both text written to the terminal using print(...) as well as return values
        are sent to process_output (see compute_answer below). In other words, the text generated is:

        res = compute_answer_print()
        txt = (any terminal output generated above)
        numbers = (any numbers found in terminal-output txt)

        self.test(process_output(res, txt, numbers), <expected result>)

        :return: Optional values for comparison
        """
        raise Exception("Generate output here. The output is passed to self.process_output")

    def process_output(self, res, txt, numbers):
        """Turn the raw result, captured text and extracted numbers into the answer.

        The default simply returns ``res``.
        """
        return res

    def compute_answer(self, unmute=False):
        """Run :meth:`compute_answer_print` under output capture and post-process it.

        The captured text has progress-bar lines removed before numbers are
        extracted. The triple ``(res, text, numbers)`` is stored in
        ``self._computed_answer``.
        """
        with Capturing(unmute=unmute) as captured_lines:
            result = self.compute_answer_print()
        text = rm_progress_bar("\n".join(captured_lines))  # Remove progress bar.
        found_numbers = extract_numbers(text)
        self._computed_answer = (result, text, found_numbers)
        return self.process_output(result, text, found_numbers)
class OrderedClassMembers(type):
    """Metaclass that records the definition order of class members.

    Every class created with it gets an ``__ordered__`` list: the names from
    its own body (in definition order, minus ``__module__`` and
    ``__qualname__``) followed by the ``__ordered__`` lists of its bases.
    """

    @classmethod
    def __prepare__(mcs, name, bases):
        """Return the ordered mapping used as the class namespace."""
        return collections.OrderedDict()

    def __new__(mcs, name, bases, classdict):
        """Build the class and attach its ``__ordered__`` list."""
        ordered = list(classdict.keys())
        for base in bases:
            # Bases created without this metaclass have no __ordered__; they
            # contribute nothing instead of raising AttributeError.
            ordered += getattr(base, '__ordered__', [])
        classdict['__ordered__'] = [key for key in ordered if key not in ('__module__', '__qualname__')]
        return type.__new__(mcs, name, bases, classdict)
class QuestionGroup(metaclass=OrderedClassMembers):
    """A question made up of the ``QItem`` classes declared in its body.

    The items are instantiated lazily, in declaration order, the first time
    :attr:`items` is read.
    """

    title = "Untitled question"
    partially_scored = False
    t_init = 0  # Time spent on initialization (placeholder; set this externally).
    estimated_time = 0.42
    has_called_init_ = False
    _name = None
    _items = None

    @property
    def items(self):
        """Return the list of item instances, creating it on first access."""
        if self._items is None:
            # Assign the list before touching any attribute: ``items`` itself is
            # among the ordered names, so getattr below re-enters this property.
            self._items = []
            skipped = ("__classcell__", "__init__")
            candidates = [getattr(self, attr) for attr in self.__ordered__ if attr not in skipped]
            item_classes = [c for c in candidates if inspect.isclass(c) and issubclass(c, QItem)]
            for item_class in item_classes:
                self._items.append(item_class(question=self))
        return self._items

    @items.setter
    def items(self, value):
        self._items = value

    @property
    def name(self):
        """Return the explicit name if one was set, otherwise the class name."""
        if self._name is None:
            self._name = self.__class__.__name__
        return self._name

    @name.setter
    def name(self, val):
        self._name = val

    def init(self):
        """Hook for sub-classes; the default does nothing."""
        # Can be used to set resources relevant for this question instance.
        pass

    def init_all_item_questions(self):
        """Call ``init()`` once on the question object of every item."""
        for item in self.items:
            if not item.question.has_called_init_:
                item.question.init()
                item.question.has_called_init_ = True
class Report():
    """A collection of questions that can be run and whose caches can be exchanged.

    ``questions`` is iterated as pairs; only the first element of each pair
    (the question class) is used by the methods of this class.
    """

    title = "report title"
    version = None
    questions = []
    pack_imports = []
    individual_imports = []

    @classmethod
    def reset(cls):
        """Call ``reset()`` on every question class that defines it."""
        for question, _ in cls.questions:
            if hasattr(question, 'reset'):
                question.reset()

    def _file(self):
        """Return the path of the source file that defines this report's class."""
        return inspect.getfile(type(self))

    def __init__(self, strict=False, payload=None):
        """Record the working directory and name, and apply ``payload`` if given.

        :param strict: Forwarded to :meth:`set_payload`.
        :param payload: Mapping handed to :meth:`set_payload`; skipped when ``None``.
        """
        source_dir = os.path.dirname(self._file())
        working_directory = os.path.abspath(source_dir)

        self.wdir, self.name = setup_dir_by_class(self, working_directory)

        # Loading a pre-computed answer file from disk used to happen here when
        # no payload was passed; that code path is disabled.
        if payload is not None:
            self.set_payload(payload, strict=strict)

    def main(self, verbosity=1):
        """Run every question with the standard unittest runner and time it.

        The elapsed wall-clock time is stored on each question class as ``time``.
        """
        # Run all tests using standard unittest (nothing fancy).
        import time
        import unittest
        loader = unittest.TestLoader()
        for question, _ in self.questions:
            started = time.time()  # A good proxy for setup time is to
            suite = loader.loadTestsFromTestCase(question)
            unittest.TextTestRunner(verbosity=verbosity).run(suite)
            question.time = time.time() - started

    def _setup_answers(self):
        """Run all questions, save their caches and return them keyed by qualified name."""
        self.main()  # Run all tests in class just to get that out of the way...
        report_cache = {}
        for question, _ in self.questions:
            key = question.__qualname__
            if not hasattr(question, '_save_cache'):
                report_cache[key] = {'no cache see _setup_answers in unitgrade2.py': True}
                continue
            question()._save_cache()
            question._cache['time'] = question.time
            report_cache[key] = question._cache
        return report_cache

    def set_payload(self, payloads, strict=False):
        """Assign ``payloads[qualified name]`` to the ``_cache`` of every question class.

        :param payloads: Mapping from question ``__qualname__`` to its cache.
        :param strict: Accepted but not used by the current implementation.
        """
        # The per-item payload validation that used to live here is disabled.
        for question, _ in self.questions:
            question._cache = payloads[question.__qualname__]
def rm_progress_bar(txt):
    """Drop every line of ``txt`` that looks like a progress bar.

    A line is dropped when it contains a ``%`` that is not its first character,
    followed later by two ``|`` characters. The bar length may vary, so only the
    order of these symbols is checked.

    :param txt: Text, possibly spanning several lines.
    :return: The remaining lines joined with ``"\\n"``.
    """
    def looks_like_bar(line):
        percent_at = line.find("%")
        if percent_at <= 0:
            return False
        first_pipe = line.find("|", percent_at + 1)
        return first_pipe > 0 and line.find("|", first_pipe + 1) > 0

    return "\n".join(line for line in txt.splitlines() if not looks_like_bar(line))
def extract_numbers(txt):
    """Return every numeric constant found in ``txt``, in order of appearance.

    Tokens containing a decimal point or an exponent (``e`` or ``E``) become
    ``float``; all others become ``int``.

    :param txt: Text to scan.
    :return: List of ``int``/``float`` values.
    :raises Exception: If more than 500 numbers are found; ``txt`` is printed first.
    """
    # Raw string: the pattern is unchanged, but the backslashes are no longer
    # interpreted (with a warning) as string escapes.
    numeric_const_pattern = r'[-+]? (?: (?: \d* \. \d+ ) | (?: \d+ \.? ) )(?: [Ee] [+-]? \d+ ) ?'
    rx = re.compile(numeric_const_pattern, re.VERBOSE)
    tokens = rx.findall(txt)
    # The pattern accepts an upper-case exponent, so it must be recognised here
    # too; otherwise a token such as "1E5" is passed to int() and raises.
    numbers = [float(tok) if ('.' in tok or 'e' in tok.lower()) else int(tok) for tok in tokens]
    if len(numbers) > 500:
        print(txt)
        raise Exception("unitgrade.unitgrade.py: Warning, many numbers!", len(numbers))
    return numbers
class ActiveProgress():
    """Progress bar that advances on its own in a background thread.

    The bar has ``round(t / 0.1)`` steps and advances one step every 0.1
    seconds, stopping one step short of completion until :meth:`terminate`
    is called.
    """

    def __init__(self, t, start=True, title="my progress bar"):
        """Configure the bar and, unless ``start`` is false, launch it at once.

        :param t: Expected duration in seconds.
        :param start: Whether to call :meth:`start` immediately.
        :param title: Description shown next to the bar.
        """
        self.t = t
        self._running = False
        self.title = title
        self.dt = 0.1
        self.n = int(np.round(self.t / self.dt))
        if start:
            self.start()

    def start(self):
        """Start the background thread and record the start time."""
        self._running = True
        self.thread = threading.Thread(target=self.run)
        self.thread.start()
        self.time_started = time.time()

    def terminate(self):
        """Stop the bar, wait for the thread and return the elapsed seconds.

        :raises Exception: If the bar is not currently running.
        """
        if not self._running:
            raise Exception("Stopping a stopped progress bar. ")
        self._running = False
        self.thread.join()
        # The thread clears ``pbar`` itself when it notices the stop request.
        if getattr(self, 'pbar', None) is not None:
            self.pbar.update(1)
            self.pbar.close()
            self.pbar = None

        sys.stdout.flush()
        return time.time() - self.time_started

    def run(self):
        """Thread body: create the tqdm bar and tick it until stopped."""
        self.pbar = tqdm.tqdm(total=self.n, file=sys.stdout, position=0, leave=False, desc=self.title, ncols=100,
                              bar_format='{l_bar}{bar}| [{elapsed}<{remaining}]')

        ticks_left = self.n - 1  # Don't terminate completely; leave bar at 99% done until terminate.
        while ticks_left > 0:
            if not self._running:
                self.pbar.close()
                self.pbar = None
                break

            time.sleep(self.dt)
            self.pbar.update(1)
            ticks_left -= 1
444from unittest.suite import _isnotsuite
class MySuite(unittest.suite.TestSuite):  # Not sure we need this one anymore.
    """Plain ``TestSuite`` subclass that adds no behaviour of its own."""
def instance_call_stack(instance):
    """Return the class names along the MRO of ``instance``, joined by ``-``."""
    class_names = [klass.__name__ for klass in instance.__class__.mro()]
    return "-".join(class_names)
def get_class_that_defined_method(meth):
    """Return the class in whose ``__dict__`` the method ``meth`` is defined.

    Bound methods (including bound class methods) are resolved through the
    object they are bound to. Plain functions are resolved, when possible,
    through their ``__qualname__`` inside their module. ``im_class`` is still
    honoured when present.

    :param meth: Bound method or function.
    :return: The defining class, or ``None`` if it cannot be determined.
    """
    # ``im_class`` only exists on Python 2 methods; on Python 3 reading it
    # unconditionally raised AttributeError for every input.
    owner = getattr(meth, "im_class", None)
    if owner is None:
        bound_to = getattr(meth, "__self__", None)
        if bound_to is not None:
            owner = bound_to if inspect.isclass(bound_to) else type(bound_to)
    if owner is None and inspect.isfunction(meth):
        # "Outer.Inner.method" -> walk "Outer.Inner" from the module namespace.
        container_path = meth.__qualname__.split(".<locals>", 1)[0].rsplit(".", 1)[0]
        candidate = inspect.getmodule(meth)
        for part in container_path.split("."):
            candidate = getattr(candidate, part, None)
        if inspect.isclass(candidate):
            owner = candidate
    if owner is None:
        return None

    for cls in inspect.getmro(owner):
        if meth.__name__ in cls.__dict__:
            return cls
    return None
def caller_name(skip=2):
    """Get a name of a caller in the format module.class.method

    `skip` specifies how many levels of stack to skip while getting caller
    name. skip=1 means "who calls me", skip=2 "who calls my caller" etc.

    An empty string is returned if skipped levels exceed stack height
    """
    frames = inspect.stack()
    if skip + 1 > len(frames):
        return ''
    parentframe = frames[skip][0]

    parts = []
    # The module can be None when the frame is executed directly in a console.
    module = inspect.getmodule(parentframe)
    if module:
        parts.append(module.__name__)
    # A local called ``self`` is taken as a sign of a method call; a static
    # method call cannot be told apart from a plain function call.
    if 'self' in parentframe.f_locals:
        parts.append(parentframe.f_locals['self'].__class__.__name__)
    codename = parentframe.f_code.co_name
    if codename != '<module>':  # top level usually
        parts.append(codename)  # function or a method

    # Avoid circular refs and frame leaks
    # https://docs.python.org/2.7/library/inspect.html#the-interpreter-stack
    del parentframe, frames

    return ".".join(parts)
def get_class_from_frame(fr):
    """Return the class of ``self`` for a frame that belongs to a method call.

    :param fr: A frame object.
    :return: ``self.__class__`` when the frame's first argument is named
        ``self`` and its value is truthy; ``None`` otherwise.
    """
    import inspect
    arg_names, _, _, frame_values = inspect.getargvalues(fr)
    # Only frames whose first parameter is literally named 'self' qualify.
    if not arg_names or arg_names[0] != 'self':
        return None
    instance = frame_values.get('self', None)
    if not instance:
        return None
    return getattr(instance, '__class__', None)
511from typing import Any
512import inspect, gc
def giveupthefunc():
    """Return the function object whose code is executing in this very frame.

    The garbage collector is asked for all referrers of the frame's code
    object; those that are plain functions sharing this code and these globals
    are kept.

    :return: The single matching function, or ``None`` when there is no match
        or more than one.
    """
    frame = inspect.currentframe()
    code, globs = frame.f_code, frame.f_globals
    functype = type(lambda: 0)
    matches = []
    for referrer in gc.get_referrers(code):
        if type(referrer) is not functype:
            continue
        if getattr(referrer, "__code__", None) is not code:
            continue
        if getattr(referrer, "__globals__", None) is not globs:
            continue
        matches.append(referrer)
        if len(matches) > 1:
            return None
    return matches[0] if matches else None
530from collections import defaultdict
class UTextResult(unittest.TextTestResult):
    """Text test result that prints one titled PASS/FAILED line per test.

    With ``show_progress_bar`` enabled an ``ActiveProgress`` bar is shown while
    a test runs and replaced by the result line afterwards. Successful tests
    are collected in ``successes``.
    """

    nL = 80  # Width of a result line.
    show_progress_bar = True

    def __init__(self, stream, descriptions, verbosity):
        super().__init__(stream, descriptions, verbosity)
        self.successes = []

    def printErrors(self) -> None:
        """Print the collected errors followed by the collected failures."""
        self.printErrorList('ERROR', self.errors)
        self.printErrorList('FAIL', self.failures)

    def addError(self, test, err):
        """Record an error and print the FAILED line for the current test."""
        # Go straight to TestResult (skipping TextTestResult's own printing).
        # This used to call addFailure, which filed errors under ``failures``.
        super(unittest.TextTestResult, self).addError(test, err)
        self.cc_terminate(success=False)

    def addFailure(self, test, err):
        """Record a failure and print the FAILED line for the current test."""
        super(unittest.TextTestResult, self).addFailure(test, err)
        self.cc_terminate(success=False)

    def addSuccess(self, test: unittest.case.TestCase) -> None:
        """Remember the successful test and print its PASS line."""
        self.successes.append(test)
        self.cc_terminate()

    def cc_terminate(self, success=True):
        """Finish the line of the current test with PASS or FAILED.

        The elapsed time is appended when it is at least 0.1 seconds.
        """
        if self.show_progress_bar:
            tsecs = np.round(self.cc.terminate(), 2)
            sys.stdout.flush()
            title = self.item_title_print
            print(title + ('.' * max(0, self.nL - 4 - len(title))), end="")
        else:
            # startTest already printed the title; only the timing is needed.
            started = getattr(self, '_item_started', None)
            tsecs = 0 if started is None else np.round(time.time() - started, 2)
        status = "PASS" if success else "FAILED"
        if tsecs >= 0.1:
            status += " (" + str(tsecs) + " seconds)"
        print(status)

    def startTest(self, test):
        """Count the test and show its title (as a progress bar or as plain text)."""
        self.testsRun += 1
        # NOTE(review): question and item numbers are hard-coded, so every
        # title is printed as "q2.2"; kept as in the original.
        n = 1
        j = 1
        item_title = self.getDescription(test)
        item_title = item_title.split("\n")[0]

        self.item_title_print = "*** q%i.%i) %s" % (n + 1, j + 1, item_title)
        estimated_time = 10
        self._item_started = time.time()
        if self.show_progress_bar:
            self.cc = ActiveProgress(t=estimated_time, title=self.item_title_print)
        else:
            print(self.item_title_print + ('.' * max(0, self.nL - 4 - len(self.item_title_print))), end="")

        self._test = test

    def _setupStdout(self):
        """Start a progress bar for the question when no test class has run yet."""
        if self._previousTestClass is None:
            total_estimated_time = 2
            if hasattr(self.__class__, 'q_title_print'):
                q_title_print = self.__class__.q_title_print
            else:
                q_title_print = "<unnamed test. See unitgrade.py>"

            self.cc = ActiveProgress(t=total_estimated_time, title=q_title_print)

    def _restoreStdout(self):  # Used when setting up the test.
        """Stop the question's progress bar and print its title line and a rule."""
        if self._previousTestClass is None:
            q_time = np.round(self.cc.terminate(), 2)
            sys.stdout.flush()
            print(self.cc.title, end="")
            nL = 80
            print(" " * max(0, nL - len(self.cc.title)) + (
                " (" + str(q_time) + " seconds)" if q_time >= 0.1 else ""))
            print("=" * nL)
639 print("=" * nL)
641from unittest.runner import _WritelnDecorator
642from io import StringIO
class UTextTestRunner(unittest.TextTestRunner):
    """Text test runner whose own stream is discarded and whose results print to stdout."""

    def __init__(self, *args, **kwargs):
        """Initialise the base runner with a throw-away in-memory stream."""
        from io import StringIO
        super().__init__(*args, stream=StringIO(), **kwargs)

    def _makeResult(self):
        """Create the result object, wired to ``sys.stdout`` rather than ``self.stream``."""
        stdout_writer = _WritelnDecorator(sys.stdout)
        return self.resultclass(stdout_writer, self.descriptions, self.verbosity)
def wrapper(foo):
    """Wrap the method ``foo`` in a one-argument function that keeps its docstring.

    The wrapped call's return value is discarded; the wrapper returns ``None``.
    """
    def magic(self):
        # The MRO string is built but not used (its print is disabled).
        s = "-".join(klass.__name__ for klass in self.__class__.mro())
        foo(self)

    magic.__doc__ = foo.__doc__
    return magic
664from functools import update_wrapper, _make_key, RLock
665from collections import namedtuple
_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])


def cache(foo, typed=False):
    """ Magic cache wrapper
    https://github.com/python/cpython/blob/main/Lib/functools.py

    Decorate the method ``foo`` so that its result is looked up in, and stored
    through, the cache interface of the instance it is called on. The instance
    must provide ``cache_id()`` (returning a tuple), ``_cache_contains(key)``,
    ``_cache_get(key)`` and ``_cache_put(key, value)``.

    :param foo: Method to wrap; called as ``foo(self, *args, **kwargs)``.
    :param typed: Forwarded to ``functools._make_key``.
    :return: The wrapping function, carrying ``foo``'s name and docstring.
    """
    def wrapper(self, *args, **kwargs):
        key = self.cache_id() + ("cache", _make_key(args, kwargs, typed))
        if self._cache_contains(key):
            return self._cache_get(key)
        value = foo(self, *args, **kwargs)
        self._cache_put(key, value)
        return value

    # Keep the wrapped method's metadata so its docstring and name survive
    # decoration.
    return update_wrapper(wrapper, foo)
class UTestCase(unittest.TestCase):
    """Test case with class-level caches of expected and computed values.

    ``_cache`` is only read by this class (it is filled from a pickle file or
    assigned externally); ``_cache2`` receives everything written through
    :meth:`_cache_put`. All three stores live on the class, not the instance.
    """

    _outcome = None  # A dictionary which stores the user-computed outcomes of all the tests. This differs from the cache.
    _cache = None  # Read-only cache.
    _cache2 = None  # User-written cache

    @classmethod
    def reset(cls):
        """Clear the outcome store and both caches of this class."""
        cls._outcome = None
        cls._cache = None
        cls._cache2 = None

    def _get_outcome(self):
        """Return the class-level outcome dictionary, creating it if needed."""
        # The original condition tested a (always truthy) tuple instead of
        # calling hasattr.
        if not hasattr(self.__class__, '_outcome') or self.__class__._outcome is None:
            self.__class__._outcome = {}
        return self.__class__._outcome

    def _callTestMethod(self, testMethod):
        """Run the test method and record its title, return value and duration."""
        t = time.time()
        res = testMethod()
        elapsed = time.time() - t
        sd = self.shortDescription()
        self._cache_put((self.cache_id(), 'title'), self._testMethodName if sd is None else sd)
        self._get_outcome()[self.cache_id()] = res
        self._cache_put((self.cache_id(), "time"), elapsed)

    def cache_id(self):
        """Return ``(class qualified name, test method name)``."""
        c = self.__class__.__qualname__
        m = self._testMethodName
        return (c, m)

    def unique_cache_id(self):
        """Return :meth:`cache_id` extended by the first index not yet in ``_cache2``."""
        k0 = self.cache_id()
        key = ()
        for i in itertools.count():
            key = k0 + (i,)
            if not self._cache2_contains(key):
                break
        return key

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._load_cache()
        self.cache_indexes = defaultdict(lambda: 0)

    def _ensure_cache_exists(self):
        """Make sure both class-level caches are dictionaries."""
        if not hasattr(self.__class__, '_cache') or self.__class__._cache is None:
            self.__class__._cache = dict()
        if not hasattr(self.__class__, '_cache2') or self.__class__._cache2 is None:
            self.__class__._cache2 = dict()

    def _cache_get(self, key, default=None):
        """Look ``key`` up in the read-only cache."""
        self._ensure_cache_exists()
        return self.__class__._cache.get(key, default)

    def _cache_put(self, key, value):
        """Store ``value`` under ``key`` in the user-written cache."""
        self._ensure_cache_exists()
        self.__class__._cache2[key] = value

    def _cache_contains(self, key):
        """Tell whether ``key`` is in the read-only cache."""
        self._ensure_cache_exists()
        return key in self.__class__._cache

    def _cache2_contains(self, key):
        """Tell whether ``key`` is in the user-written cache."""
        self._ensure_cache_exists()
        return key in self.__class__._cache2

    def assertEqualC(self, first: Any, msg: Any = ...) -> None:
        """Assert that ``first`` equals the cached value for the next unique key.

        When the read-only cache has no entry for the key a warning is printed
        and ``first`` is compared with itself. ``first`` is then stored in the
        user-written cache.

        :param msg: Optional failure message; the ``...`` default means none.
        """
        if msg is ...:
            # Do not leak the placeholder default into the failure message.
            msg = None
        key = self.unique_cache_id()
        if not self._cache_contains(key):
            print("Warning, framework missing key", key)

        self.assertEqual(first, self._cache_get(key, first), msg)
        self._cache_put(key, first)

    def _cache_file(self):
        """Return the pickle path: ``<dir of class source>/unitgrade/<ClassName>.pkl``."""
        return os.path.dirname(inspect.getfile(self.__class__)) + "/unitgrade/" + self.__class__.__name__ + ".pkl"

    def _save_cache(self):
        """Pickle the user-written cache to :meth:`_cache_file`, creating the directory."""
        cfile = self._cache_file()
        target_dir = os.path.dirname(cfile)
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)

        if hasattr(self.__class__, '_cache2'):
            with open(cfile, 'wb') as f:
                pickle.dump(self.__class__._cache2, f)

    # But you can also set cache explicitly.
    def _load_cache(self):
        """Fill the read-only cache from :meth:`_cache_file` unless it is already set."""
        if self._cache is not None:  # Cache already loaded. We will not load it twice.
            return
        cfile = self._cache_file()
        print("Loading cache from", cfile)
        if os.path.exists(cfile):
            # NOTE(security): unpickling executes code from the file; only load
            # cache files that come from a trusted source.
            with open(cfile, 'rb') as f:
                data = pickle.load(f)
            self.__class__._cache = data
        else:
            print("Warning! data file not found", cfile)
# Identity decorator; it is rebound to its registering variant further down.
def hide(func):
    return func


def makeRegisteringDecorator(foreignDecorator):
    """
    Returns a copy of foreignDecorator, which is identical in every
    way(*), except also appends a .decorator property to the callable it
    spits out.
    """
    def newDecorator(func):
        # Apply the foreign decorator exactly as a direct call would have done,
        # then tag the result with the decorator that produced it.
        decorated = foreignDecorator(func)
        decorated.decorator = newDecorator
        return decorated

    # (*) Name and docstring are copied, but the signature is not preserved.
    # The only argument is func, so this is not a big issue.
    newDecorator.__name__ = foreignDecorator.__name__
    newDecorator.__doc__ = foreignDecorator.__doc__
    return newDecorator


hide = makeRegisteringDecorator(hide)
def methodsWithDecorator(cls, decorator):
    """
    Returns all methods in CLS with DECORATOR as the
    outermost decorator.

    DECORATOR must be a "registering decorator"; one
    can make any decorator "registering" via the
    makeRegisteringDecorator function.

    Only members defined directly in CLS (its own ``__dict__``) are
    inspected, and each match is printed before it is yielded.

    import inspect
    ls = list(methodsWithDecorator(GeneratorQuestion, deco))
    for f in ls:
        print(inspect.getsourcelines(f) ) # How to get all hidden questions.
    """
    for member in cls.__dict__.values():
        if not hasattr(member, 'decorator'):
            continue
        if member.decorator == decorator:
            print(member)
            yield member