Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2git add . && git commit -m "Options" && git push && pip install git+ssh://git@gitlab.compute.dtu.dk/tuhe/unitgrade.git --upgrade 

3 

4""" 

5from . import cache_read 

6import unittest 

7import numpy as np 

8import os 

9import sys 

10from io import StringIO 

11import collections 

12import inspect 

13import re 

14import threading 

15import tqdm 

16import time 

17import pickle 

18import itertools 

19 

# Thin named wrappers around numpy / builtin reductions.
# NOTE(review): the author marks these as required; presumably other modules
# look them up by name -- confirm before renaming or removing.
myround = lambda x: np.round(x) # required.
msum = lambda x: sum(x)
mfloor = lambda x: np.floor(x)

23 

def setup_dir_by_class(C, base_dir):
    """Return ``(base_dir, class_name)`` for the instance ``C``.

    ``base_dir`` is handed back unchanged; no directory is created on disk.

    :param C: Any object; only the name of its class is used.
    :param base_dir: Directory path to pass through.
    :return: Tuple of ``base_dir`` and the class name of ``C``.
    """
    return base_dir, C.__class__.__name__

30 

class Hidden:
    """Marker class whose ``hide`` method always answers ``True``."""

    def hide(self):
        """Report that this object is hidden (always ``True``)."""
        return True

34 

class Logger(object):
    """Tee-style stream: everything written goes to the terminal and to a buffer.

    The terminal is whatever ``sys.stdout`` is at construction time.
    """

    def __init__(self, buffer):
        """
        :param buffer: File-like object with a ``write`` method that receives a
            copy of everything written to this logger.
        """
        self.terminal = sys.stdout
        self.log = buffer

    def write(self, message):
        """Write ``message`` to the terminal first, then to the buffer."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams.

        Forwarding the flush (instead of silently ignoring it) keeps
        ``print(..., flush=True)`` working while this object stands in for
        ``sys.stdout``. Streams without a ``flush`` method are skipped.
        """
        for stream in (self.terminal, self.log):
            flush = getattr(stream, "flush", None)
            if callable(flush):
                flush()

47 

class Capturing(list):
    """Context manager that collects printed output as a list of lines.

    While the context is active ``sys.stdout`` is replaced; on exit the
    captured text is split into lines and appended to this list. Usage::

        with Capturing() as lines:
            print("hello")
        # lines == ["hello"]

    :param unmute: When true, output is still echoed to the real stdout (via
        ``Logger``) while being captured.
    """

    def __init__(self, *args, unmute=False, **kwargs):
        self.unmute = unmute
        super().__init__(*args, **kwargs)

    def __enter__(self, capture_errors=True):  # don't put arguments here.
        """Start capturing stdout and discard stderr while the context is active."""
        self._stdout = sys.stdout
        self._stringio = StringIO()
        if self.unmute:
            sys.stdout = Logger(self._stringio)
        else:
            sys.stdout = self._stringio

        if capture_errors:
            self._sterr = sys.stderr
            # Fixed: the original assigned to the misspelled ``sys.sterr``, so
            # stderr was never redirected and a stray attribute leaked onto sys.
            sys.stderr = StringIO()  # memory hole it
        self.capture_errors = capture_errors
        return self

    def __exit__(self, *args):
        """Store the captured lines and restore the original streams."""
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout
        if self.capture_errors:
            sys.stderr = self._sterr

73 

74 

class QItem(unittest.TestCase):
    """A single gradable item: compute an answer and compare it to a stored one.

    Subclasses implement ``compute_answer``. The expected value is read from
    ``self._correct_answer_payload``, which is not defined here and must be
    assigned from outside before ``get_points`` is called.

    The comparison function is ``testfun``; when a subclass leaves it unset it
    defaults to ``assertL2Relative`` if ``tol > 0`` and ``assertEqual`` otherwise.
    """

    # `test` takes arguments, so the class cannot be run as an ordinary test
    # case; this keeps pytest/nose-style collectors from trying.
    __test__ = False

    title = None
    testfun = None
    tol = 0
    estimated_time = 0.42
    _precomputed_payload = None
    # Fixed: _safe_get_title reads this attribute, which was never defined.
    _precomputed_title = None
    _computed_answer = None  # Internal helper to later get results.
    weight = 1  # the weight of the question.

    def __init__(self, question=None, *args, **kwargs):
        """
        :param question: Owning question object; stored as ``self.question``.
        :param args: Forwarded to ``unittest.TestCase.__init__``.
        :param kwargs: Forwarded to ``unittest.TestCase.__init__``.
        """
        if self.testfun is None:
            self.testfun = self.assertL2Relative if self.tol > 0 else self.assertEqual

        self.name = self.__class__.__name__
        self.question = question

        super().__init__(*args, **kwargs)
        if self.title is None:
            self.title = self.name

    def _safe_get_title(self):
        """Return the precomputed title when one is set, else ``self.title``."""
        if self._precomputed_title is not None:
            return self._precomputed_title
        return self.title

    def _fail_not_close(self, computed, expected, detail):
        """Raise the test-failure exception for a tolerance violation.

        Raising directly (rather than delegating to ``assertEqual``) also works
        for numpy arrays, whose ``==`` result has no single truth value.
        """
        raise self.failureException(f"{computed!r} != {expected!r} : {detail}")

    def assertNorm(self, computed, expected, tol=None):
        """Fail unless the Euclidean norm of ``computed - expected`` is <= ``tol``.

        :param tol: Absolute tolerance; defaults to ``self.tol``.
        The norm is stored in ``self.error_computed``.
        """
        if tol is None:
            tol = self.tol
        # Subtract first, then flatten, so inputs of any matching shape work.
        diff = np.abs(np.asarray(computed) - np.asarray(expected)).ravel()
        nrm = np.sqrt(np.sum(diff ** 2))

        self.error_computed = nrm

        if nrm > tol:
            print(f"Not equal within tolerance {tol}; norm of difference was {nrm}")
            print(f"Element-wise differences {diff.tolist()}")
            self._fail_not_close(computed, expected, f"Not equal within tolerance {tol}")

    def assertL2(self, computed, expected, tol=None):
        """Fail unless the largest absolute element-wise difference is <= ``tol``.

        Despite the name, the quantity compared is the maximum absolute
        deviation, which is also stored in ``self.error_computed``.

        :param tol: Absolute tolerance; defaults to ``self.tol``.
        """
        if tol is None:
            tol = self.tol
        diff = np.abs((np.asarray(computed) - np.asarray(expected)))
        self.error_computed = np.max(diff)

        if np.max(diff) > tol:
            print(f"Not equal within tolerance {tol=}; deviation was {np.max(diff)=}")
            print(f"Element-wise differences {diff.tolist()}")
            self._fail_not_close(computed, expected, f"Not equal within tolerance {tol=}, {np.max(diff)=}")

    def assertL2Relative(self, computed, expected, tol=None):
        """Fail if any element's relative difference exceeds ``tol``.

        The relative difference is ``|c - e| / (1e-8 + |c + e|)``; its maximum
        is stored in ``self.error_computed``.

        :param tol: Relative tolerance; defaults to ``self.tol``.
        """
        if tol is None:
            tol = self.tol
        diff = np.abs((np.asarray(computed) - np.asarray(expected)))
        diff = diff / (1e-8 + np.abs((np.asarray(computed) + np.asarray(expected))))
        self.error_computed = np.max(np.abs(diff))
        if np.any(diff > tol):
            print(f"Not equal within tolerance {tol}")
            print(f"Element-wise differences {diff.tolist()}")
            self._fail_not_close(computed, expected, f"Not equal within tolerance {tol}")

    def precomputed_payload(self):
        """Return the stored precomputed payload (``None`` unless assigned)."""
        return self._precomputed_payload

    def precompute_payload(self):
        """Hook for subclasses; does nothing by default."""
        # Pre-compute resources to include in tests (useful for getting around rng).
        pass

    def compute_answer(self, unmute=False):
        """Produce the answer to be graded; subclasses must override.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("test code here")

    def test(self, computed, expected):
        """Compare ``computed`` with ``expected`` using ``self.testfun``."""
        self.testfun(computed, expected)

    def get_points(self, verbose=False, show_expected=False, show_computed=False,unmute=False, passall=False, silent=False, **kwargs):
        """Run the item and return ``(obtained, possible)`` with ``possible == 1``.

        :param verbose: Accepted but not used.
        :param show_expected: Print the expected value before testing.
        :param show_computed: Print the computed value before testing.
        :param unmute: Forwarded to ``compute_answer``.
        :param passall: Skip the comparison and ignore errors raised while
            computing the answer, so the item scores 1.
        :param silent: Suppress the diagnostic printout on error or mismatch.
        :param kwargs: Accepted and ignored.
        :return: ``(1, 1)`` on success, ``(0, 1)`` otherwise.
        """
        possible = 1
        computed = None

        def show_computed_(computed):
            print(">>> Your output:")
            print(computed)

        def show_expected_(expected):
            print(">>> Expected output (note: may have been processed; read text script):")
            print(expected)

        correct = self._correct_answer_payload
        try:
            if unmute:  # Required to not mix together print stuff.
                print("")
            computed = self.compute_answer(unmute=unmute)
        except Exception as e:
            if not passall:
                if not silent:
                    import traceback
                    print("\n=================================================================================")
                    print(f"When trying to run test class '{self.name}' your code threw an error:", e)
                    show_expected_(correct)
                    print(traceback.format_exc())
                    print("=================================================================================")
                return (0, possible)

        # Keep an answer a subclass already recorded during compute_answer.
        if self._computed_answer is None:
            self._computed_answer = computed

        if show_expected or show_computed:
            print("\n")
        if show_expected:
            show_expected_(correct)
        if show_computed:
            show_computed_(computed)
        try:
            if not passall:
                self.test(computed=computed, expected=correct)
        except Exception as e:
            if not silent:
                print("\n=================================================================================")
                print(f"Test output from test class '{self.name}' does not match expected result. Test error:")
                print(e)
                show_computed_(computed)
                show_expected_(correct)
            return (0, possible)
        return (1, possible)

    def score(self):
        """Return 1 when the item passes and 0 otherwise.

        Fixed: the original called ``self.test()`` without its two required
        arguments, so it always returned 0. It now delegates to ``get_points``
        without diagnostics; any error (including a missing payload) scores 0.
        """
        try:
            obtained, _ = self.get_points(silent=True)
        except Exception:
            return 0
        return obtained

206 

class QPrintItem(QItem):
    """Item whose answer is built from a return value plus printed output."""

    def compute_answer_print(self):
        """
        Produce the output to be tested; subclasses must override this.

        Both the text printed to the terminal and the return value are handed
        to ``process_output`` (see ``compute_answer``). In effect:

            res = compute_answer_print()
            txt = (any terminal output generated above)
            numbers = (any numbers found in terminal-output txt)

            self.test(process_output(res, txt, numbers), <expected result>)

        :return: Optional values for comparison
        """
        raise Exception("Generate output here. The output is passed to self.process_output")

    def process_output(self, res, txt, numbers):
        """Select what is compared; by default only the return value ``res``."""
        return res

    def compute_answer(self, unmute=False):
        """Run ``compute_answer_print`` under capture and post-process the result.

        The raw ``(res, text, numbers)`` triple is stored in
        ``self._computed_answer``; the value returned is whatever
        ``process_output`` makes of that triple.
        """
        with Capturing(unmute=unmute) as captured:
            result = self.compute_answer_print()
        # Progress-bar lines are dropped before numbers are extracted.
        text = rm_progress_bar("\n".join(captured))
        found = extract_numbers(text)
        self._computed_answer = (result, text, found)
        return self.process_output(result, text, found)

234 

class OrderedClassMembers(type):
    """Metaclass that records class-body names, in order, as ``__ordered__``.

    ``__ordered__`` lists the names defined in the class body (minus
    ``__module__`` and ``__qualname__``) followed by the ``__ordered__`` lists
    of the bases, in base order.
    """

    @classmethod
    def __prepare__(mcs, name, bases, **kwargs):
        """Use an ordered mapping as the class namespace."""
        return collections.OrderedDict()

    def __new__(mcs, name, bases, classdict, **kwargs):
        """Create the class and attach its ``__ordered__`` list.

        Bases not created by this metaclass carry no ``__ordered__`` attribute
        and contribute nothing (the original raised ``AttributeError`` there).
        Class keyword arguments are forwarded to ``type.__new__``.
        """
        names = list(classdict.keys())
        for base in bases:
            names += getattr(base, '__ordered__', ())
        classdict['__ordered__'] = [key for key in names if key not in ('__module__', '__qualname__')]
        return type.__new__(mcs, name, bases, classdict, **kwargs)

245 

class QuestionGroup(metaclass=OrderedClassMembers):
    # A question made up of QItem classes declared as members of the class body.
    # The metaclass records member order, so the order of definitions below is
    # significant and no class docstring is added (it would be recorded too).
    title = "Untitled question"
    partially_scored = False
    t_init = 0  # Time spend on initialization (placeholder; set this externally).
    estimated_time = 0.42
    has_called_init_ = False
    _name = None
    _items = None

    @property
    def items(self):
        """Item instances for this question, created on first access.

        Every recorded member name except ``__classcell__`` and ``__init__`` is
        looked up on the instance; members that are ``QItem`` subclasses are
        instantiated with ``question=self``, in recorded order.
        """
        if self._items is None:
            # Assigned before the lookups below: 'items' is itself a recorded name.
            self._items = []
            skipped = ["__classcell__", "__init__"]
            candidates = [getattr(self, key) for key in self.__ordered__ if key not in skipped]
            item_classes = [c for c in candidates if inspect.isclass(c) and issubclass(c, QItem)]
            for item_class in item_classes:
                self._items.append(item_class(question=self))
        return self._items

    @items.setter
    def items(self, value):
        self._items = value

    @property
    def name(self):
        """Name of the question; defaults to the class name."""
        if self._name is None:
            self._name = self.__class__.__name__
        return self._name

    @name.setter
    def name(self, val):
        self._name = val

    def init(self):
        """Hook for setting up resources for this question; no-op by default."""
        # Can be used to set resources relevant for this question instance.
        pass

    def init_all_item_questions(self):
        """Call ``init()`` once on the question of every item."""
        for item in self.items:
            owner = item.question
            if owner.has_called_init_:
                continue
            owner.init()
            owner.has_called_init_ = True

287 

288 

class Report():
    """A collection of questions, each stored as a ``(question_class, value)`` pair.

    Subclasses are expected to override the class attributes below.
    """
    title = "report title"
    version = None
    questions = []
    pack_imports = []
    individual_imports = []

    @classmethod
    def reset(cls):
        """Call ``reset()`` on every question class that defines it."""
        for question, _ in cls.questions:
            if not hasattr(question, 'reset'):
                continue
            question.reset()

    def _file(self):
        """Path of the source file that defines this report's class."""
        return inspect.getfile(type(self))

    def __init__(self, strict=False, payload=None):
        """
        :param strict: Forwarded to ``set_payload``.
        :param payload: Optional mapping of precomputed caches; when given it
            is installed through ``set_payload``.
        """
        report_dir = os.path.dirname(self._file())
        self.wdir, self.name = setup_dir_by_class(self, os.path.abspath(report_dir))

        # Loading answers from a file on disk is disabled; only an explicitly
        # supplied payload is installed.
        if payload is not None:
            self.set_payload(payload, strict=strict)

    def main(self, verbosity=1):
        """Run every question with the standard unittest runner.

        The wall-clock time of each run is stored on the question class as
        ``time``.
        """
        # Run all tests using standard unittest (nothing fancy).
        import time
        import unittest
        loader = unittest.TestLoader()
        for question, _ in self.questions:
            began = time.time()
            suite = loader.loadTestsFromTestCase(question)
            unittest.TextTestRunner(verbosity=verbosity).run(suite)
            question.time = time.time() - began

    def _setup_answers(self):
        """Run all questions and return their caches keyed by ``__qualname__``."""
        self.main()  # Run all tests in class just to get that out of the way...
        collected = {}
        for question, _ in self.questions:
            key = question.__qualname__
            if not hasattr(question, '_save_cache'):
                collected[key] = {'no cache see _setup_answers in unitgrade2.py':True}
                continue
            question()._save_cache()
            question._cache['time'] = question.time
            collected[key] = question._cache
        return collected

    def set_payload(self, payloads, strict=False):
        """Install ``payloads[question.__qualname__]`` as each question's ``_cache``.

        :param payloads: Mapping from question ``__qualname__`` to cache.
        :param strict: Accepted but not used.
        :raises KeyError: If a question has no entry in ``payloads``.
        """
        for question, _ in self.questions:
            question._cache = payloads[question.__qualname__]

350 

351 # for item in q.items: 

352 # if q.name not in payloads or item.name not in payloads[q.name]: 

353 # s = f"> Broken resource dictionary submitted to unitgrade for question {q.name} and subquestion {item.name}. Framework will not work." 

354 # if strict: 

355 # raise Exception(s) 

356 # else: 

357 # print(s) 

358 # else: 

359 # item._correct_answer_payload = payloads[q.name][item.name]['payload'] 

360 # item.estimated_time = payloads[q.name][item.name].get("time", 1) 

361 # q.estimated_time = payloads[q.name].get("time", 1) 

362 # if "precomputed" in payloads[q.name][item.name]: # Consider removing later. 

363 # item._precomputed_payload = payloads[q.name][item.name]['precomputed'] 

364 # try: 

365 # if "title" in payloads[q.name][item.name]: # can perhaps be removed later. 

366 # item.title = payloads[q.name][item.name]['title'] 

367 # except Exception as e: # Cannot set attribute error. The title is a function (and probably should not be). 

368 # pass 

369 # # print("bad", e) 

370 # self.payloads = payloads 

371 

372 

def rm_progress_bar(txt):
    """Return ``txt`` without the lines that look like a progress bar.

    A line counts as a progress bar when it contains a ``%`` that is not its
    first character, followed later by two ``|`` characters. The bar's length
    may vary, so only the order of these symbols is checked.
    """
    def looks_like_bar(line):
        percent_at = line.find("%")
        if percent_at <= 0:
            return False
        first_pipe = line.find("|", percent_at + 1)
        return first_pipe > 0 and line.find("|", first_pipe + 1) > 0

    return "\n".join(line for line in txt.splitlines() if not looks_like_bar(line))

386 

def extract_numbers(txt):
    """Return every number found in ``txt``, in order of appearance.

    Tokens containing a decimal point or an exponent (``e`` or ``E``) become
    ``float``; all others become ``int``.

    :param txt: Text to scan.
    :return: List of ``int``/``float`` values.
    :raises Exception: If more than 500 numbers are found; the text is printed
        first.
    """
    # Raw string so the regex escapes are not read as string escapes.
    numeric_const_pattern = r'[-+]? (?: (?: \d* \. \d+ ) | (?: \d+ \.? ) )(?: [Ee] [+-]? \d+ ) ?'
    rx = re.compile(numeric_const_pattern, re.VERBOSE)
    # Fixed: an upper-case exponent such as "1E5" used to reach int() and raise
    # ValueError, because only a lower-case "e" was checked.
    numbers = [
        float(token) if ('.' in token or 'e' in token.lower()) else int(token)
        for token in rx.findall(txt)
    ]
    if len(numbers) > 500:
        print(txt)
        raise Exception("unitgrade.unitgrade.py: Warning, many numbers!", len(numbers))
    return numbers

397 

398 

class ActiveProgress():
    """Time-driven tqdm progress bar that advances on a background thread.

    The bar is split into steps of ``dt`` seconds and is sized to span ``t``
    seconds in total.
    """
    def __init__(self, t, start=True, title="my progress bar"):
        """
        :param t: Expected duration in seconds; determines the number of steps.
        :param start: Start the background thread immediately.
        :param title: Description shown next to the bar.
        """
        self.t = t
        self._running = False
        self.title = title
        self.dt = 0.1
        self.n = int(np.round(self.t / self.dt))
        # self.pbar = tqdm.tqdm(total=self.n)
        if start:
            self.start()

    def start(self):
        """Start the background thread that draws the bar and record the start time."""
        self._running = True
        self.thread = threading.Thread(target=self.run)
        self.thread.start()
        self.time_started = time.time()

    def terminate(self):
        """Stop the bar and return the seconds elapsed since ``start()``.

        :raises Exception: If the bar is not currently running.
        """
        if not self._running:
            raise Exception("Stopping a stopped progress bar. ")
        self._running = False
        # Wait for run() to see the flag and finish before touching the bar.
        self.thread.join()
        # run() closes the bar itself (and sets it to None) when it sees the
        # stop flag; the bar is only still open here if run() completed all steps.
        if hasattr(self, 'pbar') and self.pbar is not None:
            self.pbar.update(1)
            self.pbar.close()
            self.pbar=None

        sys.stdout.flush()
        return time.time() - self.time_started

    def run(self):
        """Thread body: advance the bar one step every ``dt`` seconds."""
        self.pbar = tqdm.tqdm(total=self.n, file=sys.stdout, position=0, leave=False, desc=self.title, ncols=100,
                              bar_format='{l_bar}{bar}| [{elapsed}<{remaining}]')  # , unit_scale=dt, unit='seconds'):

        for _ in range(self.n-1):  # Don't terminate completely; leave bar at 99% done until terminate.
            if not self._running:
                self.pbar.close()
                self.pbar = None
                break

            time.sleep(self.dt)
            self.pbar.update(1)

441 

442 

443 

444from unittest.suite import _isnotsuite 

445 

# Plain TestSuite subclass that adds no behaviour of its own.
class MySuite(unittest.suite.TestSuite):  # Not sure we need this one anymore.
    pass

448 

def instance_call_stack(instance):
    """Return the class names in ``instance``'s MRO joined by ``-``.

    For an instance of ``class B(A)`` the result is ``"B-A-object"``.
    """
    return "-".join(klass.__name__ for klass in instance.__class__.mro())

452 

def get_class_that_defined_method(meth):
    """Return the class in whose body the bound method ``meth`` is defined.

    The MRO of the class the method is bound to is searched for the first
    class whose ``__dict__`` contains the method's name.

    :param meth: A bound method (instance method or classmethod).
    :return: The defining class, or ``None`` when ``meth`` is not bound to
        anything or no class in the MRO defines it.
    """
    # ``im_class`` exists only on Python 2 bound methods. On Python 3 the
    # owner is recovered from ``__self__`` (the instance, or the class itself
    # for classmethods); the original raised AttributeError there.
    owner = getattr(meth, "im_class", None)
    if owner is None:
        bound_to = getattr(meth, "__self__", None)
        if bound_to is None:
            return None
        owner = bound_to if inspect.isclass(bound_to) else type(bound_to)

    for cls in inspect.getmro(owner):
        if meth.__name__ in cls.__dict__:
            return cls
    return None

458 

def caller_name(skip=2):
    """Describe a frame on the call stack as ``module.class.method``.

    ``skip`` selects the frame: 1 is the function that called ``caller_name``,
    2 is that function's caller, and so on.

    The class part is included only when the frame has a local named ``self``;
    the function part is omitted for module-level code. An empty string is
    returned when ``skip`` reaches past the top of the stack.
    """
    frames = inspect.stack()
    if skip >= len(frames):
        return ''
    target = frames[skip][0]

    parts = []
    # The module can be None when the frame is executed directly in a console.
    owner_module = inspect.getmodule(target)
    if owner_module:
        parts.append(owner_module.__name__)
    # A local called 'self' is the only hint that the frame is a method call;
    # static methods are indistinguishable from plain functions here.
    frame_locals = target.f_locals
    if 'self' in frame_locals:
        parts.append(frame_locals['self'].__class__.__name__)
    function_name = target.f_code.co_name
    if function_name != '<module>':  # top level usually
        parts.append(function_name)  # function or a method

    # Drop frame references to avoid reference cycles and frame leaks.
    del frame_locals, target, frames

    return ".".join(parts)

494 

def get_class_from_frame(fr):
    """Return the class of ``self`` for a frame that belongs to a method call.

    :param fr: A frame object.
    :return: ``instance.__class__`` when the frame's first argument is named
        ``self`` and is bound to a truthy object; otherwise ``None``.
    """
    import inspect
    arg_info = inspect.getargvalues(fr)
    arg_names, frame_locals = arg_info[0], arg_info[3]

    # Only frames whose first parameter is literally called 'self' qualify.
    if not len(arg_names) or arg_names[0] != 'self':
        return None

    instance = frame_locals.get('self', None)
    if not instance:
        return None
    return getattr(instance, '__class__', None)

510 

511from typing import Any 

512import inspect, gc 

513 

def giveupthefunc():
    """Return the function object whose code is running in this frame.

    The garbage collector is asked for everything that refers to the current
    code object; of those, plain functions with this exact code and these
    globals are kept.

    :return: The single matching function, or ``None`` when there is no match
        or more than one.
    """
    frame = inspect.currentframe()
    code, globs = frame.f_code, frame.f_globals
    function_type = type(lambda: 0)

    matches = []
    for referrer in gc.get_referrers(code):
        if type(referrer) is not function_type:
            continue
        if getattr(referrer, "__code__", None) is not code:
            continue
        if getattr(referrer, "__globals__", None) is not globs:
            continue
        matches.append(referrer)

    if len(matches) != 1:
        return None
    return matches[0]

528 

529 

530from collections import defaultdict 

531 

class UTextResult(unittest.TextTestResult):
    """Text test result that prints one titled PASS/FAILED line per test.

    With ``show_progress_bar`` enabled, an ``ActiveProgress`` bar is shown
    while a test runs and is replaced by the result line when it finishes.
    """
    nL = 80  # Width of a printed result line.
    show_progress_bar = True

    def __init__(self, stream, descriptions, verbosity):
        super().__init__(stream, descriptions, verbosity)
        self.successes = []

    def printErrors(self) -> None:
        """Print the collected error and failure lists."""
        self.printErrorList('ERROR', self.errors)
        self.printErrorList('FAIL', self.failures)

    def addError(self, test, err):
        """Record an error and print the FAILED line.

        NOTE(review): the error is stored through the base ``addFailure``, so
        it ends up in ``self.failures`` rather than ``self.errors``. Kept as
        found -- confirm whether callers rely on this.
        """
        super(unittest.TextTestResult, self).addFailure(test, err)
        self.cc_terminate(success=False)

    def addFailure(self, test, err):
        """Record a failure and print the FAILED line."""
        # TextTestResult's own addFailure is skipped so it prints nothing extra.
        super(unittest.TextTestResult, self).addFailure(test, err)
        self.cc_terminate(success=False)

    def addSuccess(self, test: unittest.case.TestCase) -> None:
        """Remember the passing test and print the PASS line."""
        self.successes.append(test)
        self.cc_terminate()

    def cc_terminate(self, success=True):
        """Finish the current test's output line.

        With a progress bar, the bar is stopped and the title is printed again
        in its place. The line then ends with ``PASS`` or ``FAILED`` and, for
        runs of at least 0.1 seconds, the elapsed time.
        """
        if self.show_progress_bar:
            tsecs = np.round(self.cc.terminate(), 2)
            sys.stdout.flush()
            title = self.item_title_print
            print(title + ('.' * max(0, self.nL - 4 - len(title))), end="")
        else:
            # Fixed: without a progress bar nothing was timed and `tsecs` was
            # left undefined below. startTest already printed the title.
            tsecs = 0

        outcome = "PASS" if success else "FAILED"
        if tsecs >= 0.1:
            outcome += " (" + str(tsecs) + " seconds)"
        print(outcome)

    def startTest(self, test):
        """Count the test and start its progress bar (or print its title)."""
        self.testsRun += 1
        # Question and item indices are fixed placeholders for now.
        n = 1
        j = 1
        item_title = self.getDescription(test)
        item_title = item_title.split("\n")[0]

        self.item_title_print = "*** q%i.%i) %s" % (n + 1, j + 1, item_title)
        estimated_time = 10
        if self.show_progress_bar:
            self.cc = ActiveProgress(t=estimated_time, title=self.item_title_print)
        else:
            print(self.item_title_print + ('.' * max(0, self.nL - 4 - len(self.item_title_print))), end="")

        self._test = test

    def _setupStdout(self):
        """Start a progress bar while no test class has been recorded yet."""
        if self._previousTestClass is None:
            total_estimated_time = 2
            if hasattr(self.__class__, 'q_title_print'):
                q_title_print = self.__class__.q_title_print
            else:
                q_title_print = "<unnamed test. See unitgrade.py>"

            self.cc = ActiveProgress(t=total_estimated_time, title=q_title_print)

    def _restoreStdout(self):  # Used when setting up the test.
        """Stop the bar started by ``_setupStdout`` and print the title line."""
        if self._previousTestClass is None:
            q_time = self.cc.terminate()
            q_time = np.round(q_time, 2)
            sys.stdout.flush()
            print(self.cc.title, end="")
            print(" " * max(0, self.nL - len(self.cc.title)) + (
                " (" + str(q_time) + " seconds)" if q_time >= 0.1 else ""))
            print("=" * self.nL)

639 print("=" * nL) 

640 

641from unittest.runner import _WritelnDecorator 

642from io import StringIO 

643 

class UTextTestRunner(unittest.TextTestRunner):
    """Text runner whose own stream is a throwaway buffer.

    Results are created on the current ``sys.stdout`` instead of the runner's
    stream.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``TextTestRunner`` with an in-memory stream."""
        from io import StringIO
        super().__init__(*args, stream=StringIO(), **kwargs)

    def _makeResult(self):
        """Build the result object on ``sys.stdout`` (not on ``self.stream``)."""
        out = _WritelnDecorator(sys.stdout)
        return self.resultclass(out, self.descriptions, self.verbosity)

655 

def wrapper(foo):
    """Wrap the one-argument method ``foo``, keeping its docstring.

    The returned function calls ``foo(self)`` and discards its result, so it
    always returns ``None``.
    """
    def magic(self):
        # The MRO string is built but not used.
        _ = "-".join(klass.__name__ for klass in self.__class__.mro())
        foo(self)

    magic.__doc__ = foo.__doc__
    return magic

663 

664from functools import update_wrapper, _make_key, RLock 

665from collections import namedtuple 

# Cache-statistics record with fields hits, misses, maxsize and currsize.
_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])

667 

def cache(foo, typed=False):
    """Decorator that serves a method's result from the owner's cache.

    The owner (``self``) must provide ``cache_id()``, ``_cache_contains(key)``,
    ``_cache_get(key)`` and ``_cache_put(key, value)``. The key is
    ``cache_id() + ("cache", <key built from the call arguments>)``.

    On a miss the method is called and its result handed to ``_cache_put``;
    on a hit the method is not called and ``_cache_get(key)`` is returned.

    Adapted from https://github.com/python/cpython/blob/main/Lib/functools.py

    :param foo: The method to wrap.
    :param typed: Forwarded to the argument-key builder.
    :return: The wrapping function, carrying ``foo``'s name and docstring.
    """
    def wrapper(self, *args, **kwargs):
        key = self.cache_id() + ("cache", _make_key(args, kwargs, typed))
        if self._cache_contains(key):
            return self._cache_get(key)
        value = foo(self, *args, **kwargs)
        self._cache_put(key, value)
        return value

    # Keep the wrapped method's metadata (name, docstring, ...) visible.
    return update_wrapper(wrapper, foo)

682 

683 

class UTestCase(unittest.TestCase):
    """Test case with a per-class cache of expected values.

    State is kept on the class, not the instance:

    * ``_cache`` is only read; it is loaded from a pickle file or assigned
      from outside.
    * ``_cache2`` receives everything written during a run and is what
      ``_save_cache`` stores.
    * ``_outcome`` maps ``cache_id()`` to the return value of each test method.
    """
    _outcome = None  # A dictionary which stores the user-computed outcomes of all the tests. This differs from the cache.
    _cache = None  # Read-only cache.
    _cache2 = None  # User-written cache

    @classmethod
    def reset(cls):
        """Forget outcomes and both caches for this class."""
        cls._outcome = None
        cls._cache = None
        cls._cache2 = None

    def _get_outcome(self):
        """Return the class-level outcome dictionary, creating it if needed."""
        # Fixed: the original tested the truthiness of a tuple where a
        # hasattr() check was intended.
        if getattr(self.__class__, '_outcome', None) is None:
            self.__class__._outcome = {}
        return self.__class__._outcome

    def _callTestMethod(self, testMethod):
        """Run the test method and record its title, return value and duration."""
        began = time.time()
        res = testMethod()
        elapsed = time.time() - began
        sd = self.shortDescription()
        self._cache_put((self.cache_id(), 'title'), self._testMethodName if sd is None else sd)
        self._get_outcome()[self.cache_id()] = res
        self._cache_put((self.cache_id(), "time"), elapsed)

    def cache_id(self):
        """Return ``(class qualname, test method name)`` identifying this test."""
        c = self.__class__.__qualname__
        m = self._testMethodName
        return (c, m)

    def unique_cache_id(self):
        """Return ``cache_id() + (i,)`` for the smallest ``i`` not yet written."""
        k0 = self.cache_id()
        for i in itertools.count():
            key = k0 + (i,)
            if not self._cache2_contains(key):
                return key

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._load_cache()
        self.cache_indexes = defaultdict(lambda: 0)

    def _ensure_cache_exists(self):
        """Make sure both class-level caches are dictionaries."""
        if getattr(self.__class__, '_cache', None) is None:
            self.__class__._cache = dict()
        if getattr(self.__class__, '_cache2', None) is None:
            self.__class__._cache2 = dict()

    def _cache_get(self, key, default=None):
        """Read ``key`` from the read-only cache."""
        self._ensure_cache_exists()
        return self.__class__._cache.get(key, default)

    def _cache_put(self, key, value):
        """Store ``value`` under ``key`` in the written cache (``_cache2``)."""
        self._ensure_cache_exists()
        self.__class__._cache2[key] = value

    def _cache_contains(self, key):
        """Whether ``key`` is in the read-only cache."""
        self._ensure_cache_exists()
        return key in self.__class__._cache

    def _cache2_contains(self, key):
        """Whether ``key`` is in the written cache."""
        self._ensure_cache_exists()
        return key in self.__class__._cache2

    def assertEqualC(self, first: Any, msg: Any = ...) -> None:
        """Assert that ``first`` equals the cached value for the next free id.

        When the read-only cache has no entry, a warning is printed and the
        value is compared with itself. After a passing comparison ``first`` is
        written to the written cache under the same id.

        :param first: The computed value.
        :param msg: Optional failure message; the ``...`` default means none.
        """
        key = self.unique_cache_id()
        if not self._cache_contains(key):
            print("Warning, framework missing key", key)

        # Fixed: the ``...`` default used to be passed on as the message, so
        # failures read "... : Ellipsis".
        self.assertEqual(first, self._cache_get(key, first), None if msg is ... else msg)
        self._cache_put(key, first)

    def _cache_file(self):
        """Path of this class's pickle file, next to the defining source file."""
        return os.path.dirname(inspect.getfile(self.__class__)) + "/unitgrade/" + self.__class__.__name__ + ".pkl"

    def _save_cache(self):
        """Pickle the written cache (``_cache2``) to ``_cache_file()``."""
        cfile = self._cache_file()
        os.makedirs(os.path.dirname(cfile), exist_ok=True)

        if hasattr(self.__class__, '_cache2'):
            with open(cfile, 'wb') as f:
                pickle.dump(self.__class__._cache2, f)

    # But you can also set cache explicitly.
    def _load_cache(self):
        """Load the read-only cache from disk unless one is already set."""
        if self._cache is not None:  # Cache already loaded. We will not load it twice.
            return
        cfile = self._cache_file()
        print("Loading cache from", cfile)
        if os.path.exists(cfile):
            # NOTE(security): unpickling runs arbitrary code; the file must
            # come from a trusted source.
            with open(cfile, 'rb') as f:
                data = pickle.load(f)
            self.__class__._cache = data
        else:
            print("Warning! data file not found", cfile)

790 

def hide(func):
    """Identity decorator: return ``func`` unchanged."""
    return func

793 

def makeRegisteringDecorator(foreignDecorator):
    """
    Build a decorator that behaves like ``foreignDecorator`` and also tags
    whatever it returns with a ``.decorator`` attribute pointing back at the
    new decorator.

    The new decorator takes over the name and docstring of
    ``foreignDecorator``. It is not signature-preserving, which is harmless
    here because its only argument is the function being decorated.
    """
    def newDecorator(func):
        # Decorate exactly as foreignDecorator would, then record who did it.
        decorated = foreignDecorator(func)
        decorated.decorator = newDecorator
        return decorated

    for attribute in ("__name__", "__doc__"):
        setattr(newDecorator, attribute, getattr(foreignDecorator, attribute))
    return newDecorator

812 

# Rebind `hide` to its registering version, so whatever it decorates gets a
# `.decorator` attribute.
hide = makeRegisteringDecorator(hide)

814 

def methodsWithDecorator(cls, decorator):
    """
    Yield every member defined directly on CLS whose ``.decorator``
    attribute equals DECORATOR. Each match is also printed.

    DECORATOR must be a "registering decorator"; any decorator can be
    turned into one with makeRegisteringDecorator.

    Example:

        import inspect
        ls = list(methodsWithDecorator(GeneratorQuestion, deco))
        for f in ls:
            print(inspect.getsourcelines(f) ) # How to get all hidden questions.
    """
    tagged = (member for member in cls.__dict__.values() if hasattr(member, 'decorator'))
    for member in tagged:
        if member.decorator != decorator:
            continue
        print(member)
        yield member

834