# Public names exported by ``from <this module> import *``.
# NOTE(review): Typeclass and typecheck are defined in this module but are not
# listed here -- confirm that is intentional.
__all__ = [
    'accepts', 'returns', 'yields',
    'TypeCheckError', 'Length', 'Empty',
    'TypeSignatureError', 'And', 'Any', 'Class', 'Exact', 'HasAttr',
    'IsAllOf', 'IsCallable', 'IsIterable', 'IsNoneOf', 'IsOneOf',
    'IsOnlyOneOf', 'Not', 'Or', 'Self', 'Xor', 'YieldSeq',
    'register_type', 'is_registered_type', 'unregister_type',
    'Function',
]
7
8 import inspect
9 import types
10
11 from types import GeneratorType, FunctionType, MethodType, ClassType, TypeType
19
28 raise NotImplementedError("Incomplete _TC_Exception subclass (%s)" % str(self.__class__))
29
32
39
41 m = None
42 if self.right is not None:
43 m = ", expected %d" % self.right
44 return "length was %d%s" % (self.wrong, m or "")
45
52
54 return "expected %s, got %s" % (self.right, self.wrong)
55
58 self.inner = inner_exception
59
61 try:
62 return ", " + self.inner.error_message()
63 except:
64 print "'%s'" % self.inner.message
65 raw_input()
66 raise
67
69 - def __init__(self, index, inner_exception):
73
76
87
89 raise NotImplementedError("Incomplete _TC_DictError subclass: " + str(self.__class__))
90
92 - def __init__(self, key, inner_exception):
96
99
101 - def __init__(self, key, val, inner_exception):
105
108
110 - def __init__(self, yield_no, inner_exception):
114
116 raise RuntimeError("_TC_GeneratorError.message should never be called")
117
122
129
131 - def __init__(self, attr, inner_exception):
134
137
140 return "missing attribute %s" % self.attr
141
148
151
153 plural = "s"
154 if self.expected == 1:
155 plural = ""
156
157 return "only expected the generator to yield %d time%s" % (self.expected, plural)
158
162 - def __init__(self, matched_conds, inner_exception):
169
171 if self.matched_conds == 0:
172 m = "neither assertion"
173 else:
174 m = "both assertions"
175
176 return _TC_NestedError.error_message(self) + " (matched %s)" % m
177
179 - def __init__(self, checking_func, obj):
180 self.checking_func = checking_func
181 self.rejected_obj = obj
182
184 return " was rejected by %s" % self.checking_func
185
188
191 self.wrong = wrong
192 self.right = right
193
195 return "expected %s, got %s" % (self.right, self.wrong)
196
200 raise NotImplementedError("Incomplete _TS_Exception subclass (%s)" % str(self.__class__))
201
212
214 return "the signature type %s does not match %s" % (str(self.types), str(self.parameters))
215
221
223 return "the keyword '%s' in the signature is not in the function" % self.keyword
224
230
232 return "an extra positional type has been supplied"
233
239
241 return "parameter '%s' lacks a type" % self.parameter
242
246 - def __init__(self, parameter, kw_type, pos_type):
247 _TS_Exception.__init__(self, parameter, kw_type, pos_type)
248
249 self.parameter = parameter
250 self.kw_type = kw_type
251 self.pos_type = pos_type
252
254 return "parameter '%s' is provided two types (%s and %s)" % (self.parameter, str(self.kw_type), str(self.pos_type))
255
256
257
258
259
260
# Names of the optional hook methods a registered type may define.
_hooks = ("__typesig__", "__startchecking__", "__stopchecking__", "__switchchecking__")

# Every currently registered type, plus, for each hook name, the set of
# registered types that define that hook.
_registered_types = set()
_registered_hooks = {}
for _h in _hooks:
    _registered_hooks[_h] = set()
267 if not isinstance(reg_type, (types.ClassType, types.TypeType)):
268 raise ValueError("registered types must be classes or types")
269
270 valid = False
271 for hook in _hooks:
272 if hasattr(reg_type, hook):
273 getattr(_registered_hooks[hook], add_remove)(reg_type)
274 valid = True
275
276 if valid:
277 getattr(_registered_types, add_remove)(reg_type)
278 else:
279 raise ValueError("registered types must have at least one of the following methods: " + ", ".join(_hooks))
280
283
286
289
290
291
292
def Type(obj):
    """Return the check type that describes the type signature ``obj``.

    Each type registered with a ``__typesig__`` hook is asked in turn and the
    first non-None answer is returned.  The hooks are held in a set, so if
    several of them accept ``obj`` the winner is not fixed by registration
    order.

    Raises AssertionError when no registered hook recognises ``obj``.
    """
    typesig_owners = _registered_hooks['__typesig__']
    for owner in typesig_owners:
        signature = owner.__typesig__(obj)
        if signature is None:
            continue
        return signature

    message = "Object is of type '%s'; not a type" % str(type(obj))
    raise AssertionError(message)
303
305 attr = '__%schecking__' % start_stop
306
307 for reg_type in _registered_hooks[attr]:
308 getattr(reg_type, attr)(*args)
309
312
315
317 for reg_type in _registered_types:
318 if hasattr(reg_type, '__switchchecking__'):
319 getattr(reg_type, '__switchchecking__')(from_func, to_func)
320 else:
321 if hasattr(reg_type, '__stopchecking__'):
322 getattr(reg_type, '__stopchecking__')(from_func)
323 if hasattr(reg_type, '__startchecking__'):
324 getattr(reg_type, '__startchecking__')(to_func)
325
331 if isinstance(obj, types.InstanceType):
332 return obj.__class__
333 elif isinstance(obj, dict):
334 if len(obj) == 0:
335 return {}
336
337 key_types = set()
338 val_types = set()
339
340 for (k,v) in obj.items():
341 key_types.add( calculate_type(k) )
342 val_types.add( calculate_type(v) )
343
344 if len(key_types) == 1:
345 key_types = key_types.pop()
346 else:
347 key_types = Or(*key_types)
348
349 if len(val_types) == 1:
350 val_types = val_types.pop()
351 else:
352 val_types = Or(*val_types)
353
354 return {key_types: val_types}
355 elif isinstance(obj, tuple):
356 return tuple([calculate_type(t) for t in obj])
357 elif isinstance(obj, list):
358 length = len(obj)
359 if length == 0:
360 return []
361 obj = [calculate_type(o) for o in obj]
362
363 partitions = [1]
364 partitions.extend([i for i in range(2, int(length/2)+1) if length%i==0])
365 partitions.append(length)
366
367 def evaluate(items_per):
368 parts = length / items_per
369
370 for i in range(0, parts):
371 for j in range(0, items_per):
372 if obj[items_per * i + j] != obj[j]:
373 raise StopIteration
374 return obj[0:items_per]
375
376 for items_per in partitions:
377 try:
378 return evaluate(items_per)
379 except StopIteration:
380 continue
381 else:
382 return type(obj)
383
389 return type(self).name + '(' + ', '.join(sorted(repr(t) for t in self._types)) + ')'
390
391 __str__ = __repr__
392
394 return not self != other
395
397 return not self == other
398
400 raise NotImplementedError("Incomplete CheckType subclass: %s" % self.__class__)
401
403 raise NotImplementedError("Incomplete CheckType subclass: %s" % self.__class__)
404
405 @classmethod
407 if isinstance(obj, CheckType):
408 return obj
409
411 name = "Single"
412
414 if not isinstance(type, (types.ClassType, types.TypeType)):
415 raise TypeError("Cannot type-check a %s" % type(type))
416 else:
417 self.type = type
418
419 self._types = [self.type]
420
422 if not isinstance(to_check, self.type):
423 raise _TC_TypeError(to_check, self.type)
424
426 if other.__class__ is not self.__class__:
427 return False
428 return self.type == other.type
429
431 return hash(str(hash(self.__class__)) + str(hash(self.type)))
432
433
434
436 return repr(self.type)
437
438 @classmethod
440 if isinstance(obj, (types.ClassType, types.TypeType)):
441 return Single(obj)
442
445 name = "Empty"
446
448 if not hasattr(type, '__len__'):
449 raise TypeError("Can only assert emptyness for types with __len__ methods")
450
451 Single.__init__(self, type)
452
461
462 -class Dict(CheckType):
463 name = "Dict"
464
466 self.__check_key = Type(key)
467 self.__check_val = Type(val)
468
469 self.type = {key: val}
470 self._types = [key, val]
471
473 if not isinstance(to_check, types.DictType):
474 raise _TC_TypeError(to_check, self.type)
475
476 for (k, v) in to_check.items():
477
478 try:
479 check_type(self.__check_key, func, k)
480 except _TC_Exception, inner:
481 raise _TC_KeyError(k, inner)
482
483
484 try:
485 check_type(self.__check_val, func, v)
486 except _TC_Exception, inner:
487 raise _TC_KeyValError(k, v, inner)
488
490 if other.__class__ is not self.__class__:
491 return False
492 return self.type == other.type
493
495 cls = self.__class__
496 key = self.__check_key
497 val = self.__check_val
498
499 def strhash(obj):
500 return str(hash(obj))
501
502 return hash(''.join(map(strhash, [cls, key, val])))
503
504 @classmethod
506 if isinstance(obj, dict):
507 if len(obj) == 0:
508 return Empty(dict)
509 return Dict(obj.keys()[0], obj.values()[0])
510
511
512 -class List(CheckType):
513 name = "List"
514
516 self._types = [Type(t) for t in type]
517 self.type = [t.type for t in self._types]
518
520 if not isinstance(to_check, list):
521 raise _TC_TypeError(to_check, self.type)
522 if len(to_check) % len(self._types):
523 raise _TC_LengthError(len(to_check))
524
525
526
527
528
529
530
531
532
533
534
535
536 pat_len = len(self._types)
537 type_tuples = [(i, val, self._types[i % pat_len]) for (i, val)
538 in enumerate(to_check)]
539 for (i, val, type) in type_tuples:
540 try:
541 check_type(type, func, val)
542 except _TC_Exception, e:
543 raise _TC_IndexError(i, e)
544
546 if other.__class__ is not self.__class__:
547 return False
548
549 if len(self._types) != len(other._types):
550 return False
551
552 for (s, o) in zip(self._types, other._types):
553 if s != o:
554 return False
555 return True
556
558 def strhash(obj):
559 return str(hash(obj))
560
561 return hash(''.join(map(strhash, [self.__class__] + self._types)))
562
563 @classmethod
565 if isinstance(obj, list):
566 if len(obj) == 0:
567 return Empty(list)
568 return List(*obj)
569
572 name = "Tuple"
573
575 List.__init__(self, *type)
576
577 self.type = tuple(self.type)
578
580
581
582 if not isinstance(to_check, types.TupleType) or len(to_check) != len(self._types):
583 raise _TC_TypeError(to_check, self.type)
584
585 for (i, (val, type)) in enumerate(zip(to_check, self._types)):
586 try:
587 check_type(type, func, val)
588 except _TC_Exception, inner:
589 raise _TC_IndexError(i, inner)
590
591 @classmethod
593 if isinstance(obj, tuple):
594 return Tuple(*obj)
595
597
598
599
600
601 __mapping_stack = []
602
603
604
605 __active_mapping = None
606
607
608 __gen_mappings = {}
609
612
614 return "TypeVariable(%s)" % self.type
615
616 __repr__ = __str__
617
619 return hash(''.join([str(o) for o in self.__class__
620 , hash(type(self.type))
621 , hash(self.type)]))
622
624 if self.__class__ is not other.__class__:
625 return False
626 return type(self.type) is type(other.type) and self.type == other.type
627
645
646 @classmethod
648 if isinstance(obj, basestring):
649 return cls(obj)
650
651 @classmethod
660
661 @classmethod
674
675 @classmethod
683
686 self._func = func
687 self.type = self
688
689 @classmethod
691 if isinstance(obj, (FunctionType, MethodType)):
692 return cls(obj)
693
694
695 if type(obj) not in (types.ClassType, type) and callable(obj):
696 return cls(obj)
697
699 if False == self._func(to_check):
700 raise _TC_FunctionError(self._func, to_check)
701
703 return "Function(%s)" % self._func
704
707
709 if self.__class__ is not other.__class__:
710 return False
711 return self._func is other._func
712
714 return hash(str(self.__class__) + str(hash(self._func)))
715
716
# Register the built-in check types with the hook registry.  Type() consults
# the registered '__typesig__' hooks (presumably filled by register_type).
for c in (CheckType,
          List, Tuple, Dict,
          Single, TypeVariables, Function):
    register_type(c)
719
720
721
722
723
724
725
726 -class Any(CheckType):
727 name = "Any"
728
731
734
737
738 __repr__ = __str__
739
740
742 return other.__class__ is self.__class__
743
745 return hash(self.__class__)
746
749 - def __init__(self, first_type, second_type, *types):
750 self._types = set()
751
752 for t in (first_type, second_type)+types:
753 if type(t) is type(self):
754 self._types.update(t._types)
755 else:
756 self._types.add(Type(t))
757
758 if len(self._types) < 2:
759 raise TypeError("there must be at least 2 distinct parameters to __init__()")
760
761 self.type = self
762
764 if other.__class__ is not self.__class__:
765 return False
766
767 return self._types == other._types
768
770 return hash(str(hash(self.__class__)) + str(hash(frozenset(self._types))))
771
784
785 -class And(_Boolean):
794
796 name = "Not"
797
798
799
800 - def __init__(self, first_type, *types):
801 self._types = set([Type(t) for t in (first_type,)+types])
802
803 self.type = self
804
812
813 -class Xor(_Boolean):
831
835
837 return "IsCallable()"
838
839 __repr__ = __str__
840
841
842
844 return id(self.__class__)
845
847 return self.__class__ is other.__class__
848
850 if not callable(to_check):
851 raise _TC_TypeError(to_check, 'a callable')
852
855 attr_sets = {list: [], dict: {}}
856
857 for (arg_1, arg_2) in ((set_1, set_2), (set_2, set_1)):
858 for t in (list, dict):
859 if isinstance(arg_1, t):
860 attr_sets[t] = arg_1
861 if isinstance(arg_2, t):
862 raise TypeError("can only have one list and/or one dict")
863
864 self._attr_types = dict.fromkeys(attr_sets[list], Any())
865
866 for (attr, typ) in attr_sets[dict].items():
867 self._attr_types[attr] = Type(typ)
868
878
880 if self.__class__ is not other.__class__:
881 return False
882 return self._attr_types == other._attr_types
883
885 return hash(str(hash(self.__class__)) + str(hash(str(self._attr_types))))
886
888 any_type = []
889 spec_type = {}
890
891 any = Any()
892
893 for (attr, typ) in self._attr_types.items():
894 if typ == any:
895 any_type.append(attr)
896 else:
897 spec_type[attr] = typ
898
899 msg = [t for t in (any_type, spec_type) if len(t)]
900
901 return "HasAttr(" + ', '.join(map(str, msg)) + ")"
902
903 __repr__ = __str__
904
908
910 return self.__class__ is other.__class__
911
912
913
915 return id(self.__class__)
916
918 return "IsIterable()"
919
920 __repr__ = __str__
921
923 if not (hasattr(to_check, '__iter__') and callable(to_check.__iter__)):
924 raise _TC_TypeError(to_check, "an iterable")
925
927 _index_map = {}
928
929 - def __init__(self, type_1, type_2, *types):
930 self.type = self
931
932 self._type = [type_1, type_2] + list(types)
933 self._types = [Type(t) for t in self._type]
934
937
939 return "YieldSeq(" + ", ".join(map(str, self._type)) + ")"
940
941 __repr__ = __str__
942
944 if self.__class__ is not other.__class__:
945 return False
946 return self._types == other._types
947
949 return hash(str(self.__class__) + str([hash(t) for t in self._types]))
950
951
952
953 @classmethod
955 if isinstance(gen, GeneratorType):
956 cls._index_map[gen] = {}
957
958 @classmethod
962
964 index_map = self.__class__._index_map
965
966
967 if self not in index_map[gen]:
968 index_map[gen][self] = -1
969 index = index_map[gen]
970
971 if index[self] >= len(self._types)-1:
972 raise _TC_YieldCountError(len(self._types))
973
974 index[self] += 1
975 check_type(self._types[index[self]], gen, to_check)
976
# YieldSeq is defined after the bulk registration loop for the other built-in
# check types, so it is registered separately here.
register_type(YieldSeq)
978
979 -class Exact(CheckType):
981 self.type = self
982 self._obj = obj
983
985 try:
986 obj_hash = str(hash(self._obj))
987 except TypeError:
988 obj_hash = str(type(self._obj)) + str(self._obj)
989
990 return hash(str(self.__class__) + obj_hash)
991
993 if self.__class__ is not other.__class__:
994 return False
995 return self._obj == other._obj
996
998 if self._obj != to_check:
999 raise _TC_ExactError(to_check, self._obj)
1000
1003 self.type = self
1004 self._length = int(length)
1005
1007 return hash(str(self.__class__) + str(self._length))
1008
1010 if self.__class__ is not other.__class__:
1011 return False
1012 return self._length == other._length
1013
1015 try:
1016 length = len(to_check)
1017 except TypeError:
1018 raise _TC_TypeError(to_check, "something with a __len__ method")
1019
1020 if length != self._length:
1021 raise _TC_LengthError(length, self._length)
1022
1023 import sys
1026 self.type = self
1027 self.class_name = class_name
1028 self.class_obj = None
1029 self._frame = sys._getframe(1)
1030
1032 return hash(str(self.__class__) + self.class_name)
1033
1035 return "Class('%s')" % self.class_name
1036
1037 __repr__ = __str__
1038
1040 if self.__class__ is not other.__class__:
1041 return False
1042 return self.class_name == other.class_name
1043
1045 if self.class_obj is None:
1046 class_name = self.class_name
1047 frame = self._frame
1048
1049 for f_dict in (frame.f_locals, frame.f_globals):
1050 if class_name in frame.f_locals:
1051 if self is not frame.f_locals[class_name]:
1052 self.class_obj = frame.f_locals[class_name]
1053 self._frame = None
1054 break
1055 else:
1056 raise NameError("name '%s' is not defined" % class_name)
1057
1058 if not isinstance(to_check, self.class_obj):
1059 raise _TC_TypeError(to_check, self.class_obj)
1060
1062 bad_members = dict.fromkeys(['__class__', '__new__', '__init__'], True)
1063
1065 if len(types) == 0:
1066 raise TypeError("Must supply at least one type to __init__()")
1067
1068 self.type = self
1069
1070 self._cache = set()
1071 self._interface = set()
1072 self._instances = set()
1073 for t in types:
1074 self.add_instance(t)
1075
1076 self._calculate_interface()
1077
1081
1083 return list(self._instances)
1084
1086 return list(self._interface)
1087
1089 return instance in self._instances
1090
1092 if isinstance(instance, self.__class__):
1093 for inst in instance.instances():
1094 self._instances.add(inst)
1095 self._cache.add(inst)
1096 elif isinstance(instance, (ClassType, TypeType)):
1097 self._instances.add(instance)
1098 self._cache.add(instance)
1099 else:
1100 raise TypeError("All instances must be classes or types")
1101
1103 if isinstance(other, self.__class__):
1104 new_instances = other.instances()
1105 else:
1106 new_instances = other
1107
1108 self._instances.update(new_instances)
1109 self._cache.update(new_instances)
1110 self._calculate_interface()
1111
1113 bad_members = self.bad_members
1114
1115 for instance in self._instances:
1116 inst_attrs = []
1117
1118 for attr, obj in instance.__dict__.items():
1119 if callable(obj) and attr not in bad_members:
1120 inst_attrs.append(attr)
1121
1122 if len(self._interface) == 0:
1123 self._interface = set(inst_attrs)
1124 else:
1125 self._interface.intersection_update(inst_attrs)
1126
1128 if to_check.__class__ in self._cache:
1129 return
1130
1131 for method in self._interface:
1132 if not hasattr(to_check, method):
1133 raise _TC_MissingAttrError(method)
1134
1135 attr = getattr(to_check, method)
1136 if not callable(attr):
1137 raise _TC_AttrError(method, _TC_TypeError(attr, IsCallable()))
1138
1139 self._cache.add(to_check.__class__)
1140
1142 if self.__class__ is not other.__class__:
1143 return False
1144 return self._instances == other._instances
1145
1147 return hash(str(self.__class__) + str(hash(frozenset(self._instances))))
1148
1151
1153 return 'Typeclass(' + ', '.join(map(str, self._instances)) + ')'
1154
1159
1160
1161
1162
# Descriptive aliases for the boolean combinators.
(IsOneOf, IsAllOf, IsNoneOf, IsOnlyOneOf) = (Or, And, Not, Xor)
1174
1176 - def __init__(self, prefix, bad_object, exception):
1185
1187 return self.__message
1188
1191 Exception.__init__(self, internal_exc)
1192
1193 self.internal = internal_exc
1194 self.__message = internal_exc.error_message()
1195
1197 return self.__message
1198
1202 if isinstance(obj, list):
1203 return tuple(_rec_tuple(o) for o in obj)
1204 return obj
1205
1207 if not isinstance(obj, (list, tuple)):
1208 return obj
1209
1210 if len(obj) == 1:
1211 return '(%s,)' % obj
1212
1213 return '(' + ', '.join(_rec_tuple_str(o) for o in obj) + ')'
1214
1216 sig_args = list()
1217 dic_args = list()
1218
1219 for obj in posargs:
1220 if isinstance(obj, list):
1221 rts = _rec_tuple_str(obj)
1222
1223 sig_args.append(rts)
1224 dic_args.append((_rec_tuple(obj), rts))
1225 else:
1226 sig_args.append(str(obj))
1227 dic_args.append(('"%s"' % obj, obj))
1228
1229 func_code = ''
1230 if varargs:
1231 dic_args.append(('"%s"' % varargs, varargs))
1232 sig_args.append('*' + varargs)
1233 func_code = '\n\t%s = list(%s)' % (varargs, varargs)
1234 if varkw:
1235 dic_args.append(('"%s"' % varkw, varkw))
1236 sig_args.append('**' + varkw)
1237
1238 func_name = func.func_name + '_'
1239 while func_name in dic_args:
1240 func_name += '_'
1241
1242 func_def = 'def %s(' % func.func_name
1243 func_return = func_code \
1244 + '\n\treturn {' \
1245 + ', '.join('%s: %s' % kv for kv in dic_args) \
1246 + '}'
1247
1248 locals = {}
1249 exec func_def + ','.join(sig_args) + '):' + func_return in locals
1250 func = locals[func.func_name]
1251 func.func_defaults = defaults
1252 return func
1253
1255 if not isinstance(ref, (list, tuple)):
1256 return
1257 if not isinstance(obj, (list, tuple)):
1258 raise _TS_TupleError(ref, obj)
1259
1260 if len(ref) != len(obj):
1261 raise _TS_TupleError(ref, obj)
1262
1263 try:
1264 for r, o in zip(ref, obj):
1265 _validate_tuple(r, o)
1266 except _TS_TupleError:
1267 raise _TS_TupleError(ref, obj)
1268
1270 vargs = list(vargs)
1271 kwargs = dict(kwargs)
1272
1273
1274 param_value = dict()
1275
1276
1277 if len(params) < len(vargs) and varg_name is None:
1278 raise _TS_ExtraPositionalError(vargs[len(params)])
1279
1280 if len(params) > len(vargs) and len(kwargs) == 0:
1281 raise _TS_MissingTypeError(params[len(vargs)])
1282
1283
1284 if len(vargs):
1285 for p, a in zip(params, vargs):
1286
1287 _validate_tuple(p, a)
1288 param_value[_rec_tuple(p)] = a
1289
1290
1291 if len(kwargs) > 0:
1292
1293 params = set([k for k in params if k not in param_value])
1294 if kwarg_name and kwarg_name not in param_value:
1295 params.add(kwarg_name)
1296 if varg_name and varg_name not in param_value:
1297 params.add(varg_name)
1298
1299
1300 no_double_star = kwarg_name is None
1301
1302
1303
1304 if len(params) == 0 and no_double_star:
1305 raise _TS_ExtraKeywordError(kwargs.keys()[0])
1306
1307
1308 for p, a in kwargs.items():
1309 if p in param_value:
1310 raise _TS_TwiceTypedError(p, a, param_value[p])
1311 if p not in params and no_double_star:
1312 raise _TS_ExtraKeywordError(p)
1313
1314
1315 _validate_tuple(p, a)
1316
1317
1318 params.remove(p)
1319 param_value[p] = a
1320
1321
1322
1323 if len(params):
1324 raise _TS_MissingTypeError(params.pop())
1325
1326 return param_value
1327
1329 def fake_function(*vargs, **kwargs):
1330
1331
1332
1333
1334
1335 start_checking(func)
1336
1337
1338
1339 try:
1340 fake_function.__check_args(vargs, kwargs)
1341 result = func(*vargs, **kwargs)
1342 except:
1343 stop_checking(func)
1344 raise
1345
1346 return fake_function.__check_result(func, result)
1347
1348
1349
1350 def _pass_args(vargs, kwargs):
1351 pass
1352 def _pass_result(func, result):
1353 stop_checking(func)
1354 return result
1355
1356 fake_function.__check_args = _pass_args
1357 fake_function.__check_result = _pass_result
1358 fake_function.__wrapped_func = func
1359
1360
1361
1362 fake_function.__module__ = func.__module__
1363 fake_function.__name__ = func.__name__
1364 fake_function.__doc__ = func.__doc__
1365
1366 return fake_function
1367
1372
1373 def decorator(func):
1374 if hasattr(func, '__wrapped_func'):
1375 if hasattr(func, 'type_args'):
1376 raise RuntimeError('Cannot use the same typecheck_* function more than once on the same function')
1377 wrapped_func = func.__wrapped_func
1378 else:
1379 wrapped_func = func
1380
1381 param_list, varg_name, kwarg_name, defaults = inspect.getargspec(wrapped_func)
1382 args_to_params = _gen_arg_to_param(wrapped_func, (param_list, varg_name, kwarg_name, defaults))
1383
1384 try:
1385 param_types = _param_to_type((param_list, varg_name, kwarg_name), v_sig, kw_sig)
1386 except _TS_Exception, e:
1387 raise TypeSignatureError(e)
1388
1389
1390
1391 if varg_name:
1392 if not isinstance(param_types[varg_name], list):
1393 param_types[varg_name] = [param_types[varg_name]]
1394
1395 if kwarg_name:
1396 if not isinstance(param_types[kwarg_name], dict):
1397 param_types[kwarg_name] = {str: param_types[kwarg_name]}
1398
1399
1400
1401
1402
1403 check_param_types = dict()
1404 for k, v in param_types.items():
1405 check_param_types[k] = Type(v)
1406
1407 def __check_args(__vargs, __kwargs):
1408
1409
1410 if enable_checking:
1411 arg_dict = args_to_params(*__vargs, **__kwargs)
1412
1413
1414 try:
1415 for name, val in arg_dict.items():
1416 check_type(check_param_types[name], wrapped_func, val)
1417 except _TC_Exception, e:
1418 str_name = _rec_tuple_str(name)
1419 raise TypeCheckError("Argument %s: " % str_name, val, e)
1420
1421 if hasattr(func, '__check_result'):
1422
1423
1424 fake_function = func
1425 else:
1426
1427 fake_function = _make_fake_function(func)
1428
1429
1430 fake_function.__check_args = __check_args
1431
1432
1433 fake_function.type_args = param_types
1434
1435 return fake_function
1436 return decorator
1437
1438
def _decorator(signature, conflict_field, twice_field, check_result_func):
    """Build a decorator that installs a result (or yield) check on a function.

    signature         -- stored on the wrapper under ``twice_field``
    conflict_field    -- attribute whose presence means the mutually exclusive
                         result decorator was already applied
    twice_field       -- attribute recording that this decorator was applied
    check_result_func -- installed as the wrapper's ``__check_result``

    The returned decorator raises RuntimeError when stacked on a wrapper that
    already carries ``conflict_field`` or ``twice_field``.
    """
    def decorator(func):
        if not hasattr(func, '__check_result'):
            fake_function = _make_fake_function(func)
        else:
            # func is already a checking wrapper: refuse to stack a
            # conflicting or duplicate result check on it.
            if hasattr(func, conflict_field):
                raise RuntimeError("Cannot use typecheck_return and typecheck_yield on the same function")
            if hasattr(func, twice_field):
                raise RuntimeError('Cannot use the same typecheck_* function more than once on the same function')
            fake_function = func

        setattr(fake_function, twice_field, signature)
        fake_function.__check_result = check_result_func
        return fake_function
    return decorator
1457
1473 return _decorator(signature, 'type_yield', 'type_return', __check_return)
1474
1476 - def __init__(self, real_gen, signature):
1477
1478
1479
1480
1481 self.type_yield = signature
1482
1483 self.__yield_no = 0
1484 self.__real_gen = real_gen
1485 self.__sig_types = Type(signature)
1486 self.__needs_stopping = True
1487
1489 gen = self.__real_gen
1490
1491 self.__yield_no += 1
1492
1493 try:
1494 return_vals = gen.next()
1495 except StopIteration:
1496 if self.__needs_stopping:
1497 stop_checking(gen)
1498 self.__needs_stopping = False
1499 raise
1500
1501 if enable_checking:
1502 try:
1503 check_type(self.__sig_types, gen, return_vals)
1504 except _TC_Exception, e:
1505
1506
1507 middle_exc = _TC_GeneratorError(self.__yield_no, e)
1508 raise TypeCheckError("", return_vals, middle_exc)
1509
1510
1511 return return_vals
1512
1514 if self.__needs_stopping:
1515 stop_checking(self.__real_gen)
1516
1518 if len(signature) == 1:
1519 signature = signature[0]
1520
1521 def __check_yield(func, gen):
1522
1523 if not isinstance(gen, types.GeneratorType):
1524 stop_checking(func)
1525 raise TypeError("typecheck_yield only works for generators")
1526
1527
1528
1529
1530
1531
1532 switch_checking(func, gen)
1533
1534
1535 return Fake_generator(gen, signature)
1536 return _decorator(signature, 'type_return', 'type_yield', __check_yield)
1537
def _null_decorator(*args, **kwargs):
    """Decorator factory that ignores its arguments and returns an identity decorator."""
    def _identity(f):
        return f
    return _identity

# Until type checking is enabled (presumably enable_typechecking rebinds
# these), every public decorator is a no-op.
typecheck = _null_decorator
accepts = _null_decorator
returns = _null_decorator
yields = _null_decorator
1555
import os

# Opt-in switch: setting PYTHONTYPECHECK in the environment (to any value)
# enables type checking when the module is imported.
if os.environ.get("PYTHONTYPECHECK") is not None:
    enable_typechecking()
1559