Browse Source

Check the size of fillchar passed to str methods (ljust, rjust) (#236)

* check the size of fillchar passed to str methods (ljust, rjust)

* count characters using u8_length instead of size
albertexye 1 year ago
parent
commit
b1115a4c8f
2 changed files with 12 additions and 0 deletions
  1. 2 0
      src/pocketpy.cpp
  2. 10 0
      tests/04_str.py

+ 2 - 0
src/pocketpy.cpp

@@ -734,6 +734,7 @@ void init_builtins(VM* _vm) {
         int delta = width - self.u8_length();
         int delta = width - self.u8_length();
         if(delta <= 0) return args[0];
         if(delta <= 0) return args[0];
         const Str& fillchar = CAST(Str&, args[2]);
         const Str& fillchar = CAST(Str&, args[2]);
+        if (fillchar.u8_length() != 1) vm->TypeError("The fill character must be exactly one character long");
         SStream ss;
         SStream ss;
         ss << self;
         ss << self;
         for(int i=0; i<delta; i++) ss << fillchar;
         for(int i=0; i<delta; i++) ss << fillchar;
@@ -747,6 +748,7 @@ void init_builtins(VM* _vm) {
         int delta = width - self.u8_length();
         int delta = width - self.u8_length();
         if(delta <= 0) return args[0];
         if(delta <= 0) return args[0];
         const Str& fillchar = CAST(Str&, args[2]);
         const Str& fillchar = CAST(Str&, args[2]);
+        if (fillchar.u8_length() != 1) vm->TypeError("The fill character must be exactly one character long");
         SStream ss;
         SStream ss;
         for(int i=0; i<delta; i++) ss << fillchar;
         for(int i=0; i<delta; i++) ss << fillchar;
         ss << self;
         ss << self;

+ 10 - 0
tests/04_str.py

@@ -157,8 +157,18 @@ assert b[5:2:-2] == [',', 'l']
 a = '123'
 a = '123'
 assert a.rjust(5) == '  123'
 assert a.rjust(5) == '  123'
 assert a.rjust(5, '0') == '00123'
 assert a.rjust(5, '0') == '00123'
+try:
+    a.rjust(5, '00')
+    exit(1)
+except TypeError:
+    pass
 assert a.ljust(5) == '123  '
 assert a.ljust(5) == '123  '
 assert a.ljust(5, '0') == '12300'
 assert a.ljust(5, '0') == '12300'
+try:
+    a.ljust(5, '00')
+    exit(1)
+except TypeError:
+    pass
 
 
 assert '\x30\x31\x32' == '012'
 assert '\x30\x31\x32' == '012'