Kaynağa Gözat

add compile time func

blueloveTH 8 ay önce
ebeveyn
işleme
57cd40da6f

+ 1 - 0
include/pocketpy/interpreter/vm.h

@@ -51,6 +51,7 @@ typedef struct VM {
     void* ctx;         // user-defined context
     void* ctx;         // user-defined context
 
 
     BinTree cached_names;
     BinTree cached_names;
+    NameDict compile_time_funcs;
 
 
     py_StackRef curr_class;
     py_StackRef curr_class;
     py_StackRef curr_decl_based_function;
     py_StackRef curr_decl_based_function;

+ 5 - 0
include/pocketpy/pocketpy.h

@@ -129,6 +129,11 @@ PK_API void py_watchdog_begin(py_i64 timeout);
 /// Reset the watchdog.
 /// Reset the watchdog.
 PK_API void py_watchdog_end();
 PK_API void py_watchdog_end();
 
 
+/// Bind a compile-time function via "decl-based" style.
+PK_API void py_compiletime_bind(const char* sig, py_CFunction f);
+/// Find a compile-time function by name.
+PK_API py_ItemRef py_compiletime_getfunc(py_Name name);
+
 /// Get the current source location of the frame.
 /// Get the current source location of the frame.
 PK_API const char* py_Frame_sourceloc(py_Frame* frame, int* lineno);
 PK_API const char* py_Frame_sourceloc(py_Frame* frame, int* lineno);
 /// Python equivalent to `globals()` with respect to the given frame.
 /// Python equivalent to `globals()` with respect to the given frame.

+ 81 - 1
src/compiler/compiler.c

@@ -368,6 +368,25 @@ Literal0Expr* Literal0Expr__new(int line, TokenIndex token) {
     return self;
     return self;
 }
 }
 
 
+typedef struct LoadConstExpr {
+    EXPR_COMMON_HEADER
+    int index;
+} LoadConstExpr;
+
+void LoadConstExpr__emit_(Expr* self_, Ctx* ctx) {
+    LoadConstExpr* self = (LoadConstExpr*)self_;
+    Ctx__emit_(ctx, OP_LOAD_CONST, self->index, self->line);
+}
+
+LoadConstExpr* LoadConstExpr__new(int line, int index) {
+    const static ExprVt Vt = {.emit_ = LoadConstExpr__emit_};
+    LoadConstExpr* self = PK_MALLOC(sizeof(LoadConstExpr));
+    self->vt = &Vt;
+    self->line = line;
+    self->index = index;
+    return self;
+}
+
 typedef struct SliceExpr {
 typedef struct SliceExpr {
     EXPR_COMMON_HEADER
     EXPR_COMMON_HEADER
     Expr* start;
     Expr* start;
@@ -1864,9 +1883,70 @@ static Error* exprMap(Compiler* self) {
     return NULL;
     return NULL;
 }
 }
 
 
+static Error* read_literal(Compiler* self, py_Ref out);
+
+static Error* exprCompileTimeCall(Compiler* self, py_ItemRef func, int line) {
+    Error* err;
+    py_push(func);
+    py_pushnil();
+
+    uint16_t argc = 0;
+    uint16_t kwargc = 0;
+    // copied from `exprCall`
+    do {
+        match_newlines();
+        if(curr()->type == TK_RPAREN) break;
+        if(curr()->type == TK_ID && next()->type == TK_ASSIGN) {
+            consume(TK_ID);
+            py_Name key = py_namev(Token__sv(prev()));
+            consume(TK_ASSIGN);
+            // k=v
+            py_pushname(key);
+            check(read_literal(self, py_pushtmp()));
+            kwargc += 1;
+        } else {
+            if(kwargc > 0) {
+                return SyntaxError(self, "positional argument follows keyword argument");
+            }
+            check(read_literal(self, py_pushtmp()));
+            argc += 1;
+        }
+        match_newlines();
+    } while(match(TK_COMMA));
+    consume(TK_RPAREN);
+
+    bool ok = py_vectorcall(argc, kwargc);
+    if(!ok) {
+        char* msg = py_formatexc();
+        err = SyntaxError(self, "compile-time call error:\n%s", msg);
+        PK_FREE(msg);
+        return err;
+    }
+
+    // TODO: optimize string dedup
+    int index = Ctx__add_const(ctx(), py_retval());
+    Ctx__s_push(ctx(), (Expr*)LoadConstExpr__new(line, index));
+    return NULL;
+}
+
 static Error* exprCall(Compiler* self) {
 static Error* exprCall(Compiler* self) {
     Error* err;
     Error* err;
-    CallExpr* e = CallExpr__new(prev()->line, Ctx__s_popx(ctx()));
+    Expr* callable = Ctx__s_popx(ctx());
+    int line = prev()->line;
+    if(callable->vt->is_name) {
+        NameExpr* ne = (NameExpr*)callable;
+        if(ne->scope == NAME_GLOBAL) {
+            py_ItemRef func = py_compiletime_getfunc(ne->name);
+            if(func != NULL) {
+                py_StackRef p0 = py_peek(0);
+                err = exprCompileTimeCall(self, func, line);
+                if(err != NULL) py_clearexc(p0);
+                return err;
+            }
+        }
+    }
+
+    CallExpr* e = CallExpr__new(line, callable);
     Ctx__s_push(ctx(), (Expr*)e);  // push onto the stack in advance
     Ctx__s_push(ctx(), (Expr*)e);  // push onto the stack in advance
     do {
     do {
         match_newlines();
         match_newlines();

+ 8 - 0
src/interpreter/vm.c

@@ -116,6 +116,7 @@ void VM__ctor(VM* self) {
         .need_free_key = false,
         .need_free_key = false,
     };
     };
     BinTree__ctor(&self->cached_names, NULL, py_NIL(), &cached_names_config);
     BinTree__ctor(&self->cached_names, NULL, py_NIL(), &cached_names_config);
+    NameDict__ctor(&self->compile_time_funcs, PK_TYPE_ATTR_LOAD_FACTOR);
 
 
     /* Init Builtin Types */
     /* Init Builtin Types */
     // 0: unused
     // 0: unused
@@ -294,6 +295,7 @@ void VM__dtor(VM* self) {
     FixedMemoryPool__dtor(&self->pool_frame);
     FixedMemoryPool__dtor(&self->pool_frame);
     ValueStack__dtor(&self->stack);
     ValueStack__dtor(&self->stack);
     BinTree__dtor(&self->cached_names);
     BinTree__dtor(&self->cached_names);
+    NameDict__dtor(&self->compile_time_funcs);
 }
 }
 
 
 void VM__push_frame(VM* self, py_Frame* frame) {
 void VM__push_frame(VM* self, py_Frame* frame) {
@@ -673,6 +675,12 @@ void ManagedHeap__mark(ManagedHeap* self) {
     BinTree__apply_mark(&vm->modules, p_stack);
     BinTree__apply_mark(&vm->modules, p_stack);
     // mark cached names
     // mark cached names
     BinTree__apply_mark(&vm->cached_names, p_stack);
     BinTree__apply_mark(&vm->cached_names, p_stack);
+    // mark compile time functions
+    for(int i = 0; i < vm->compile_time_funcs.capacity; i++) {
+        NameDict_KV* kv = &vm->compile_time_funcs.items[i];
+        if(kv->key == NULL) continue;
+        pk__mark_value(&kv->value);
+    }
     // mark types
     // mark types
     int types_length = vm->types.length;
     int types_length = vm->types.length;
     // 0-th type is placeholder
     // 0-th type is placeholder

+ 17 - 3
src/public/values.c

@@ -89,9 +89,23 @@ void py_bindmagic(py_Type type, py_Name name, py_CFunction f) {
 }
 }
 
 
 void py_bind(py_Ref obj, const char* sig, py_CFunction f) {
 void py_bind(py_Ref obj, const char* sig, py_CFunction f) {
-    py_TValue tmp;
-    py_Name name = py_newfunction(&tmp, sig, f, NULL, 0);
-    py_setdict(obj, name, &tmp);
+    py_Ref tmp = py_pushtmp();
+    py_Name name = py_newfunction(tmp, sig, f, NULL, 0);
+    py_setdict(obj, name, tmp);
+    py_pop();
+}
+
+void py_compiletime_bind(const char* sig, py_CFunction f) {
+    py_Ref tmp = py_pushtmp();
+    py_Name name = py_newfunction(tmp, sig, f, NULL, 0);
+    NameDict__set(&pk_current_vm->compile_time_funcs, name, tmp);
+    py_pop();
+}
+
+PK_API py_ItemRef py_compiletime_getfunc(py_Name name) {
+    NameDict* d = &pk_current_vm->compile_time_funcs;
+    if(d->length == 0) return NULL;
+    return NameDict__try_get(d, name);
 }
 }
 
 
 py_Name py_newfunction(py_OutRef out,
 py_Name py_newfunction(py_OutRef out,

+ 1 - 1
src2/main.c

@@ -53,7 +53,7 @@ int main(int argc, char** argv) {
     py_initialize();
     py_initialize();
     py_sys_setargv(argc, argv);
     py_sys_setargv(argc, argv);
 
 
-    assert(!profile);   // not implemented yet
+    assert(!profile);  // not implemented yet
     // if(profile) py_sys_settrace(LineProfiler__tracefunc, true);
     // if(profile) py_sys_settrace(LineProfiler__tracefunc, true);
 
 
     if(filename == NULL) {
     if(filename == NULL) {