Просмотр исходного кода

`import` can be used in local scope now

blueloveTH 2 лет назад
Родитель
Сommit
eb5be9ba41
4 измененных файлов с 32 добавлено и 19 удалено
  1. 1 0
      include/pocketpy/expr.h
  2. 3 4
      src/compiler.cpp
  3. 18 13
      src/expr.cpp
  4. 10 2
      tests/30_import.py

+ 1 - 0
include/pocketpy/expr.h

@@ -65,6 +65,7 @@ struct CodeEmitContext{
     int add_varname(StrName name);
     int add_const(PyObject* v);
     int add_func_decl(FuncDecl_ decl);
+    void emit_store_name(NameScope scope, StrName name, int line);
 };
 
 struct NameExpr: Expr{

+ 3 - 4
src/compiler.cpp

@@ -477,7 +477,6 @@ __SUBSCR_END:
     // import a [as b]
     // import a [as b], c [as d]
     void Compiler::compile_normal_import() {
-        if(name_scope() != NAME_GLOBAL) SyntaxError("import statement should be used in global scope");
         do {
             consume(TK("@id"));
             Str name = prev().str();
@@ -486,7 +485,7 @@ __SUBSCR_END:
                 consume(TK("@id"));
                 name = prev().str();
             }
-            ctx()->emit(OP_STORE_GLOBAL, StrName(name).index, prev().line);
+            ctx()->emit_store_name(name_scope(), StrName(name), prev().line);
         } while (match(TK(",")));
         consume_end_stmt();
     }
@@ -499,7 +498,6 @@ __SUBSCR_END:
     // from .a.b import c [as d]
     // from xxx import *
     void Compiler::compile_from_import() {
-        if(name_scope() != NAME_GLOBAL) SyntaxError("import statement should be used in global scope");
         int dots = 0;
 
         while(true){
@@ -538,6 +536,7 @@ __EAT_DOTS_END:
         consume(TK("import"));
 
         if (match(TK("*"))) {
+            if(name_scope() != NAME_GLOBAL) SyntaxError("from <module> import * can only be used in global scope");
             // pop the module and import __all__
             ctx()->emit(OP_POP_IMPORT_STAR, BC_NOARG, prev().line);
             consume_end_stmt();
@@ -553,7 +552,7 @@ __EAT_DOTS_END:
                 consume(TK("@id"));
                 name = prev().str();
             }
-            ctx()->emit(OP_STORE_GLOBAL, StrName(name).index, prev().line);
+            ctx()->emit_store_name(name_scope(), StrName(name), prev().line);
         } while (match(TK(",")));
         ctx()->emit(OP_POP_TOP, BC_NOARG, BC_KEEPLINE);
         consume_end_stmt();

+ 18 - 13
src/expr.cpp

@@ -1,4 +1,5 @@
 #include "pocketpy/expr.h"
+#include "pocketpy/codeobject.h"
 
 namespace pkpy{
 
@@ -96,6 +97,21 @@ namespace pkpy{
         return co->func_decls.size() - 1;
     }
 
+    void CodeEmitContext::emit_store_name(NameScope scope, StrName name, int line){
+        switch(scope){
+            case NAME_LOCAL:
+                emit(OP_STORE_FAST, add_varname(name), line);
+                break;
+            case NAME_GLOBAL:
+                emit(OP_STORE_GLOBAL, StrName(name).index, line);
+                break;
+            case NAME_GLOBAL_UNKNOWN:
+                emit(OP_STORE_NAME, StrName(name).index, line);
+                break;
+            default: FATAL_ERROR(); break;
+        }
+    }
+
 
     void NameExpr::emit(CodeEmitContext* ctx) {
         int index = ctx->co->varnames_inv.try_get(name);
@@ -127,22 +143,11 @@ namespace pkpy{
 
     bool NameExpr::emit_store(CodeEmitContext* ctx) {
         if(ctx->is_compiling_class){
-            int index = StrName(name).index;
+            int index = name.index;
             ctx->emit(OP_STORE_CLASS_ATTR, index, line);
             return true;
         }
-        switch(scope){
-            case NAME_LOCAL:
-                ctx->emit(OP_STORE_FAST, ctx->add_varname(name), line);
-                break;
-            case NAME_GLOBAL:
-                ctx->emit(OP_STORE_GLOBAL, StrName(name).index, line);
-                break;
-            case NAME_GLOBAL_UNKNOWN:
-                ctx->emit(OP_STORE_NAME, StrName(name).index, line);
-                break;
-            default: FATAL_ERROR(); break;
-        }
+        ctx->emit_store_name(scope, name, line);
         return true;
     }
 

+ 10 - 2
tests/30_import.py

@@ -17,5 +17,13 @@ from test2.utils import get_value_2
 assert get_value_2() == '123'
 
 from test3.a.b import value
-# should test3
-assert value == 1
+assert value == 1
+
+def f():
+    import math as m
+    assert m.pi > 3
+
+    from test3.a.b import value
+    assert value == 1
+
+f()