diff --git a/lval/environment.c b/lval/environment.c index 5ebc75f..133b39d 100644 --- a/lval/environment.c +++ b/lval/environment.c @@ -6,6 +6,7 @@ lenv* lenv_new(void) { lenv* e = (lenv*) malloc(sizeof(lenv)); + e->par = NULL; e->count = 0; e->syms = NULL; e->vals = NULL; @@ -23,12 +24,13 @@ void lenv_del(lenv* e) { } lenv* lenv_copy(lenv* e) { - lenv* n = malloc(sizeof(lenv)); + lenv* n = (lenv*) malloc(sizeof(lenv)); + n->par = e->par; n->count = e->count; - n->syms = malloc(sizeof(char*) * n->count); - n->vals = malloc(sizeof(lval*) * n->count); + n->syms = (char**) malloc(sizeof(char*) * n->count); + n->vals = (lval**) malloc(sizeof(lval*) * n->count); for (int i = 0; i < e->count; i++) { - n->syms[i] = malloc(strlen(e->syms[i]) + 1); + n->syms[i] = (char*) malloc(strlen(e->syms[i]) + 1); strcpy(n->syms[i], e->syms[i]); n->vals[i] = lval_copy(e->vals[i]); } @@ -45,7 +47,12 @@ lval* lenv_get(lenv* e, lval* k) { } } - // If no symbol found return error + // If no symbol found so far, check in parent + if (e->par) { + return lenv_get(e->par, k); + } + + // If no symbol found and no parent, return error return lval_err("Unbounded symbol %s", k->sym); } @@ -73,16 +80,24 @@ void lenv_put(lenv* e, lval* k, lval* v) { strcpy(e->syms[e->count - 1], k->sym); } -lval* lval_fun(lbuiltin func) { - lval* v = malloc(sizeof(lval)); - v->type = LVAL_FUN; - v->builtin = func; - return v; +void lenv_def(lenv* e, lval* k, lval* v) { + // Iterate until e has no parent + while (e->par) { e = e->par; } + + // Put the value in e + lenv_put(e, k, v); +} + +lval* lval_builtin(lbuiltin func) { + lval* v = malloc(sizeof(lval)); + v->type = LVAL_FUN; + v->builtin = func; + return v; } void lenv_add_builtin(lenv* e, char* name, lbuiltin func) { lval* k = lval_sym(name); - lval* v = lval_fun(func); + lval* v = lval_builtin(func); lenv_put(e, k, v); lval_del(k); lval_del(v); } @@ -108,35 +123,11 @@ void lenv_add_builtins(lenv* e) { lenv_add_builtin(e, "max", builtin_max); lenv_add_builtin(e, "def", builtin_def); + lenv_add_builtin(e, "=", builtin_put); lenv_add_builtin(e, "ls", builtin_ls); lenv_add_builtin(e, "\\", builtin_lambda); } -lval* builtin_def(lenv* e, lval* a) { - LASSERT_TYPE("def", a, 0, LVAL_QEXPR) - - // First argument is the symbol list - lval* syms = a->cell[0]; - - // Ensure all elements of the first list are symbols - for (int i = 0; i < syms->count; i++) { - LASSERT(a, syms->cell[i]->type == LVAL_SYM, - "Function 'def' cannot define non-symbol") - } - - // Check correct number of symbols and values - LASSERT(a, syms->count == a->count - 1, - "Function 'def' cannot define incorrect number of values to symbols. Left side %i, right side %i", syms->count, a->count - 1) - - // Assign copies of values to symbols - for (int i = 0; i < syms->count; i++) { - lenv_put(e, syms->cell[i], a->cell[i + 1]); - } - - lval_del(a); - return lval_sexpr(); -} - lval* builtin_ls(lenv* e, lval* a) { LASSERT_NUM("ls", a, 0) @@ -184,4 +175,90 @@ lval* builtin_lambda(lenv* e, lval* a) { lval_del(a); return lval_lambda(formals, body); +} + +lval* builtin_var(lenv* e, lval* a, char* func) { + LASSERT_TYPE(func, a, 0, LVAL_QEXPR) + + lval* syms = a->cell[0]; + for (int i = 0; i < syms->count; i++) { + LASSERT(a, (syms->cell[i]->type == LVAL_SYM), + "Function '%s' cannot define non-symbol. " + "Got %s, Expected %s.", func, + ltype_name(syms->cell[i]->type), + ltype_name(LVAL_SYM)) + } + + LASSERT(a, (syms->count == a->count-1), + "Function '%s' passed too many arguments for symbols. " + "Got %i, Expected %i.", func, syms->count, a->count-1) + + for (int i = 0; i < syms->count; i++) { + // If 'def' define it globally + if (strcmp(func, "def") == 0) { + lenv_def(e, syms->cell[i], a->cell[i + 1]); + } + // If 'put' define it locally + if (strcmp(func, "=") == 0) { + lenv_put(e, syms->cell[i], a->cell[i + 1]); + } + } + + lval_del(a); + return lval_sexpr(); +} + +lval* builtin_def(lenv* e, lval* a) { + return builtin_var(e, a, "def"); +} +lval* builtin_put(lenv* e, lval* a) { + return builtin_var(e, a, "="); +} + +lval* lval_call(lenv* e, lval* f, lval* a) { + // If builtin simply apply that + if (f->builtin) { return f->builtin(e, a); } + + // Record argument counts + int given = a->count; + int total = f->formals->count; + + // // While arguments still remain to be processed + while (a->count) { + // If we've run out of formal arguments.. + if (f->formals->count == 0) { + lval_del(a); + return lval_err("Function passed too many arguments. " + "Got %i, Expected %i.", given, total); + } + + // Pop the first symbol from the formals + lval* sym = lval_pop(f->formals, 0); + + // Pop the next argument from the list + lval* val = lval_pop(a, 0); + + // Bind a copy into the function's environment + lenv_put(f->env, sym, val); + + // Delete the symbol and value + lval_del(sym); lval_del(val); + } + + // The argument list is now bounded so we can clean up the given + lval_del(a); + + // If all formals have been bounded evaluate + if (f->formals->count == 0) { + // Set environment parent to evaluation environment + f->env->par = e; + + // Evaluate and return + return builtin_eval( + f->env, lval_add(lval_sexpr(), lval_copy(f->body))); + } else { + // Otherwise return partially evaluated function + return lval_copy(f); + } + return lval_sexpr(); } \ No newline at end of file diff --git a/lval/environment.h b/lval/environment.h index 815ecbe..816cf8c 100644 --- a/lval/environment.h +++ b/lval/environment.h @@ -3,6 +3,8 @@ #include "base.h" struct lenv { + // Represents the parent environment + lenv* par; int count; char** syms; lval** vals; @@ -10,13 +12,16 @@ struct lenv { lenv* lenv_new(void); void lenv_del(lenv* e); +lenv* lenv_copy(lenv* e); + // Obtain a variable from the environment // e is the environment // k is the symbol lval* lenv_get(lenv* e, lval* k); void lenv_put(lenv* e, lval* k, lval* v); +void lenv_def(lenv* e, lval* k, lval* v); -lval* lval_fun(lbuiltin func); +lval* lval_builtin(lbuiltin func); void lenv_add_builtin(lenv* e, char* name, lbuiltin func); void lenv_add_builtins(lenv* e); @@ -24,6 +29,10 @@ void lenv_add_builtins(lenv* e); lval* builtin_def(lenv* e, lval* a); lval* builtin_ls(lenv* e, lval* a); lval* builtin_lambda(lenv* e, lval* a); +lval* builtin_def(lenv* e, lval* a); +lval* builtin_put(lenv* e, lval* a); +lval* builtin_var(lenv* e, lval* a, char* func); +lval* lval_call(lenv* e, lval* f, lval* a); #endif diff --git a/lval/expressions.c b/lval/expressions.c index e22f101..a84d2de 100644 --- a/lval/expressions.c +++ b/lval/expressions.c @@ -3,6 +3,7 @@ #include #include "expressions.h" +#include "environment.h" #include "numbers.h" #include "operations.h" #include "error.h" @@ -85,17 +86,20 @@ lval* lval_eval_sexpr(lenv* e, lval* v) { if (v->count == 0) { return v; } // Single expression - if (v->count == 1) { return lval_take(v, 0); } + if (v->count == 1) { return lval_eval(e, lval_take(v, 0)); } // Ensure first element is a symbol otherwise lval* f = lval_pop(v, 0); if (f->type != LVAL_FUN) { + lval* err = lval_err( + "S-Experssion starts with incorrect type. " + "Got %s, Expected %s.", + ltype_name(f->type), ltype_name(LVAL_FUN)); lval_del(f); lval_del(v); - return lval_err("S-expression does not start with function"); + return err; } - // If so call the function and return result - lval* result = f->builtin(e, v); + lval* result = lval_call(e, f, v); lval_del(f); return result; } diff --git a/lval/operations.c b/lval/operations.c index 0e6e3f6..631054c 100644 --- a/lval/operations.c +++ b/lval/operations.c @@ -3,6 +3,7 @@ #include "numbers.h" #include "expressions.h" #include "operations.h" +#include "environment.h" lval* builtin_op(lenv* e, lval* v, char* sym);