Fix scope of foreach and match statements

v = 100
for v : [1] ..
v should be 1, not 100
This commit is contained in:
aNNiMON 2023-09-07 18:40:28 +03:00 committed by Victor Melnik
parent b7c376f01f
commit 32bffaee86
11 changed files with 245 additions and 76 deletions

View File

@ -0,0 +1,8 @@
package com.annimon.ownlang.lib;
public final class AutoCloseableScope implements AutoCloseable {
@Override
public void close() {
ScopeHandler.pop();
}
}

View File

@ -36,6 +36,11 @@ public final class ScopeHandler {
scope = rootScope;
}
public static AutoCloseableScope closeableScope() {
push();
return new AutoCloseableScope();
}
public static void push() {
synchronized (lock) {
scope = new Scope(scope);

View File

@ -22,7 +22,7 @@ public final class Parser {
final Parser parser = new Parser(tokens);
final Statement program = parser.parse();
if (parser.getParseErrors().hasErrors()) {
throw new ParseException();
throw new ParseException(parser.getParseErrors().toString());
}
return program;
}

View File

@ -23,22 +23,14 @@ public final class ForeachArrayStatement extends InterruptableNode implements St
@Override
public void execute() {
super.interruptionCheck();
// TODO removing without checking shadowing is dangerous
final Value previousVariableValue = ScopeHandler.getVariable(variable);
final Value containerValue = container.eval();
switch (containerValue.type()) {
case Types.STRING -> iterateString(containerValue.asString());
case Types.ARRAY -> iterateArray((ArrayValue) containerValue);
case Types.MAP -> iterateMap((MapValue) containerValue);
default -> throw new TypeException("Cannot iterate " + Types.typeToString(containerValue.type()));
}
// Restore variables
if (previousVariableValue != null) {
ScopeHandler.setVariable(variable, previousVariableValue);
} else {
ScopeHandler.removeVariable(variable);
try (final var ignored = ScopeHandler.closeableScope()) {
final Value containerValue = container.eval();
switch (containerValue.type()) {
case Types.STRING -> iterateString(containerValue.asString());
case Types.ARRAY -> iterateArray((ArrayValue) containerValue);
case Types.MAP -> iterateMap((MapValue) containerValue);
default -> throw new TypeException("Cannot iterate " + Types.typeToString(containerValue.type()));
}
}
}

View File

@ -24,28 +24,14 @@ public final class ForeachMapStatement extends InterruptableNode implements Stat
@Override
public void execute() {
super.interruptionCheck();
// TODO removing without checking shadowing is dangerous
final Value previousVariableValue1 = ScopeHandler.getVariable(key);
final Value previousVariableValue2 = ScopeHandler.getVariable(value);
final Value containerValue = container.eval();
switch (containerValue.type()) {
case Types.STRING -> iterateString(containerValue.asString());
case Types.ARRAY -> iterateArray((ArrayValue) containerValue);
case Types.MAP -> iterateMap((MapValue) containerValue);
default -> throw new TypeException("Cannot iterate " + Types.typeToString(containerValue.type()) + " as key, value pair");
}
// Restore variables
if (previousVariableValue1 != null) {
ScopeHandler.setVariable(key, previousVariableValue1);
} else {
ScopeHandler.removeVariable(key);
}
if (previousVariableValue2 != null) {
ScopeHandler.setVariable(value, previousVariableValue2);
} else {
ScopeHandler.removeVariable(value);
try (final var ignored = ScopeHandler.closeableScope()) {
final Value containerValue = container.eval();
switch (containerValue.type()) {
case Types.STRING -> iterateString(containerValue.asString());
case Types.ARRAY -> iterateArray((ArrayValue) containerValue);
case Types.MAP -> iterateMap((MapValue) containerValue);
default -> throw new TypeException("Cannot iterate " + Types.typeToString(containerValue.type()) + " as key, value pair");
}
}
}

View File

@ -54,13 +54,10 @@ public final class MatchExpression extends InterruptableNode implements Expressi
}
}
if ((value.type() == Types.ARRAY) && (p instanceof ListPattern pattern)) {
if (matchListPattern((ArrayValue) value, pattern)) {
// Clean up variables if matched
final Value result = evalResult(p.result);
for (String var : pattern.parts) {
ScopeHandler.removeVariable(var);
try (final var ignored = ScopeHandler.closeableScope()) {
if (matchListPattern((ArrayValue) value, pattern)) {
return evalResult(p.result);
}
return result;
}
}
if ((value.type() == Types.ARRAY) && (p instanceof TuplePattern pattern)) {
@ -91,20 +88,12 @@ public final class MatchExpression extends InterruptableNode implements Expressi
final int arraySize = array.size();
switch (partsSize) {
case 0: // match [] { case []: ... }
if ((arraySize == 0) && optMatches(p)) {
return true;
}
return false;
return (arraySize == 0) && optMatches(p);
case 1: // match arr { case [x]: x = arr ... }
final String variable = parts.get(0);
ScopeHandler.defineVariableInCurrentScope(variable, array);
if (optMatches(p)) {
return true;
}
// TODO remove is dangerous
ScopeHandler.removeVariable(variable);
return false;
return optMatches(p);
default: { // match arr { case [...]: .. }
if (partsSize == arraySize) {
@ -124,16 +113,7 @@ public final class MatchExpression extends InterruptableNode implements Expressi
for (int i = 0; i < partsSize; i++) {
ScopeHandler.defineVariableInCurrentScope(parts.get(i), array.get(i));
}
if (optMatches(p)) {
// Clean up will be provided after evaluate result
return true;
}
// Clean up variables if no match
for (String var : parts) {
// TODO removing without checking shadowing is dangerous
ScopeHandler.removeVariable(var);
}
return false;
return optMatches(p);
}
private boolean matchListPatternWithTail(ListPattern p, List<String> parts, int partsSize, ArrayValue array, int arraySize) {
@ -148,16 +128,7 @@ public final class MatchExpression extends InterruptableNode implements Expressi
tail.set(i - lastPart, array.get(i));
}
ScopeHandler.defineVariableInCurrentScope(parts.get(lastPart), tail);
// Check optional condition
if (optMatches(p)) {
// Clean up will be provided after evaluate result
return true;
}
// Clean up variables
for (String var : parts) {
ScopeHandler.removeVariable(var);
}
return false;
return optMatches(p);
}
private boolean match(Value value, Value constant) {

View File

@ -57,6 +57,14 @@ public class ProgramsTest {
() -> ((FunctionValue) args[0]).getValue().execute());
return NumberValue.ONE;
});
ScopeHandler.setFunction("fail", (args) -> {
if (args.length > 0) {
fail(args[0].asString());
} else {
fail();
}
return NumberValue.ONE;
});
}
@ParameterizedTest

View File

@ -1,5 +1,7 @@
package com.annimon.ownlang.parser.ast;
import com.annimon.ownlang.lib.ScopeHandler;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import com.annimon.ownlang.exceptions.VariableDoesNotExistsException;
@ -11,6 +13,11 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
* @author aNNiMON
*/
public class VariableExpressionTest {
@BeforeEach
void setUp() {
ScopeHandler.resetScope();
}
@Test
public void testVariable() {

View File

@ -28,3 +28,16 @@ def testStringIterate() {
assertEquals("ABCD", str)
assertEquals(394/*97 + 98 + 99 + 100*/, sum)
}
def testScope() {
a = 100
b = 200
sum = 0
for a, b : {14: 3} {
sum += a
sum += b
}
assertEquals(17, sum)
assertEquals(14, a)
assertEquals(3, b)
}

View File

@ -36,6 +36,6 @@ def testScope() {
sum += v
}
assertEquals(6, sum)
assertEquals(45, v)
assertEquals(3, v)
}

View File

@ -0,0 +1,179 @@
use "types"
def testMatchValue() {
value = 20
result = match value {
case 10: "ten"
case 20: "twenty"
}
assertEquals("twenty", result)
}
def testMatchValueAny() {
value = 20
result = match value {
case 0: "zero"
case 1: "one"
case _: "other"
}
assertEquals("other", result)
}
def testMatchAdditionalCheck() {
value = 20
result = match value {
case 10: "ten"
case x if x < 10: "" + x + "<10"
case x if x > 10: "" + x + ">10"
}
assertEquals("20>10", result)
}
def testMatchAdditionalCheckScope() {
x = 20
result = match x {
case 10: "ten"
case x if x < 10: fail()
case y if y > 10: assertEquals(20, y)
}
assertEquals(20, x)
assertEquals(true, result)
}
def printArrayRecursive(arr) = match arr {
case [head :: tail]: "[" + head + ", " + printArrayRecursive(tail) + "]"
case []: "[]"
case last: "[" + last + ", []]"
}
def testMatchEmptyArray() {
result = printArrayRecursive([])
assertEquals("[]", result)
}
def testMatchOneElementArray() {
result = printArrayRecursive([1])
assertEquals("[[1], []]", result)
}
def testMatchTwoElementsArray() {
result = printArrayRecursive([1, 2])
assertEquals("[1, [2, []]]", result)
}
def testMatchArray() {
result = printArrayRecursive([1, 2, 3, 4])
assertEquals("[1, [2, [3, [4, []]]]]", result)
}
def testMatchArray2() {
def elementsCount(arr) = match arr {
case [a :: b :: c :: d :: e]: 5
case [a :: b :: c :: d]: 4
case [a :: b :: c]: 3
case [a :: b]: 2
case (7): -7 // special case 1
case [a] if a == [8]: -8 // special case 2
case []: 0
case [a]: 1
}
assertEquals(4, elementsCount([1, 2, 3, 4]))
assertEquals(3, elementsCount([1, 2, 3]))
assertEquals(2, elementsCount([1, 2]))
assertEquals(1, elementsCount([1]))
assertEquals(-7, elementsCount([7]))
assertEquals(-8, elementsCount([8]))
assertEquals(0, elementsCount([]))
}
def testMatchOneElementArrayScope() {
head = 100
tail = 200
result = match [1] {
case [head :: tail]: fail("Multi-array")
case []: fail("Empty array")
case last: assertEquals(1, last[0])
}
assertEquals(100, head)
assertEquals(200, tail)
assertEquals(true, result)
}
def testMatchOneElementArrayDefinedVariableScope() {
head = 100
tail = 200
last = 300
result = match [1] {
case [head :: tail]: fail("Multi-array")
case []: fail("Empty array")
case last: fail("Array should not be equal " + last)
case rest: assertEquals(1, rest[0])
}
assertEquals(100, head)
assertEquals(200, tail)
assertEquals(300, last)
assertEquals(true, result)
}
def testMatchArrayScope() {
head = 100
tail = 200
result = match [1, 2, 3] {
case [head :: tail]: assertEquals(1, head)
case []: fail("Empty array")
case last: fail("One element")
}
assertEquals(100, head)
assertEquals(200, tail)
assertEquals(true, result)
}
def testMatchTuple() {
result = match [1, 2] {
case (0, 1): "(0, 1)"
case (1, 2): "(1, 2)"
case (2, 3): "(2, 3)"
}
assertEquals("(1, 2)", result)
}
def testMatchTupleDifferentLength() {
result = match [1, 2] {
case (1): "(1)"
case (1, 2, 3, 4): "(1, 2, 3, 4)"
case _: "not matched"
}
assertEquals("not matched", result)
}
def testMatchTupleAny1() {
result = match [1, 2] {
case (0, _): "(0, _)"
case (1, _): "(1, _)"
case (2, _): "(2, _)"
}
assertEquals("(1, _)", result)
}
def testMatchTupleAny2() {
result = match [2, 3] {
case (0, _): "(0, _)"
case (1, _): "(1, _)"
case (_, _): "(_, _)"
}
assertEquals("(_, _)", result)
}
def testMatchTupleAny3() {
result = match [2, 3] {
case (0, _): "(0, _)"
case (1, _): "(1, _)"
case _: "_"
}
assertEquals("_", result)
}
def testScope() {
}