diff --git a/examples/basics/pattern_matching.own b/examples/basics/pattern_matching.own index 7a0ff50..cc9b118 100644 --- a/examples/basics/pattern_matching.own +++ b/examples/basics/pattern_matching.own @@ -36,7 +36,7 @@ println printArrayRecursive([1, 2, 3, 4, 5, 6, 7]) def printArrayRecursive(arr) = match arr { case [head :: tail]: "[" + head + ", " + printArrayRecursive(tail) + "]" case []: "[]" - case last: "[" + last + ", []]" + case other: "[" + other + ", []]" } println "\nPattern matching on arrays by value" diff --git a/ownlang-parser/src/main/java/com/annimon/ownlang/parser/ast/MatchExpression.java b/ownlang-parser/src/main/java/com/annimon/ownlang/parser/ast/MatchExpression.java index d1602c0..2a5e43d 100644 --- a/ownlang-parser/src/main/java/com/annimon/ownlang/parser/ast/MatchExpression.java +++ b/ownlang-parser/src/main/java/com/annimon/ownlang/parser/ast/MatchExpression.java @@ -90,10 +90,14 @@ public final class MatchExpression extends InterruptableNode implements Expressi case 0: // match [] { case []: ... } return (arraySize == 0) && optMatches(p); - case 1: // match arr { case [x]: x = arr ... } - final String variable = parts.get(0); - ScopeHandler.defineVariableInCurrentScope(variable, array); - return optMatches(p); + case 1: // match arr { case [x]: x = arr[0] ... } + if (arraySize == 1) { + final String variable = parts.get(0); + final var value = array.get(0); + ScopeHandler.defineVariableInCurrentScope(variable, value); + return optMatches(p); + } + return false; default: { // match arr { case [...]: .. } if (partsSize == arraySize) { diff --git a/ownlang-parser/src/test/resources/expressions/matchExpression.own b/ownlang-parser/src/test/resources/expressions/matchExpression.own index 2e1bcb7..5a5a1f3 100644 --- a/ownlang-parser/src/test/resources/expressions/matchExpression.own +++ b/ownlang-parser/src/test/resources/expressions/matchExpression.own @@ -43,7 +43,8 @@ def testMatchAdditionalCheckScope() { def printArrayRecursive(arr) = match arr { case [head :: tail]: "[" + head + ", " + printArrayRecursive(tail) + "]" case []: "[]" - case last: "[" + last + ", []]" + case [last]: "[" + last + ", []]" + case value: value } def testMatchEmptyArray() { @@ -53,17 +54,17 @@ def testMatchEmptyArray() { def testMatchOneElementArray() { result = printArrayRecursive([1]) - assertEquals("[[1], []]", result) + assertEquals("[1, []]", result) } def testMatchTwoElementsArray() { result = printArrayRecursive([1, 2]) - assertEquals("[1, [2, []]]", result) + assertEquals("[1, 2]", result) } def testMatchArray() { result = printArrayRecursive([1, 2, 3, 4]) - assertEquals("[1, [2, [3, [4, []]]]]", result) + assertEquals("[1, [2, [3, 4]]]", result) } def testMatchArray2() { @@ -73,9 +74,9 @@ def testMatchArray2() { case [a :: b :: c]: 3 case [a :: b]: 2 case (7): -7 // special case 1 - case [a] if a == [8]: -8 // special case 2 - case []: 0 + case [a] if a == 8: -8 // special case 2 case [a]: 1 + case []: 0 } assertEquals(4, elementsCount([1, 2, 3, 4])) assertEquals(3, elementsCount([1, 2, 3])) @@ -86,6 +87,16 @@ def testMatchArray2() { assertEquals(0, elementsCount([])) } +def testMatchArray3() { + def elementD(arr) = match arr { + case [a :: b :: c :: d]: d + case _: [] + } + assertEquals(4, elementD([1, 2, 3, 4])) + assertEquals([4, 5, 6], elementD([1, 2, 3, 4, 5, 6])) + assertEquals([], elementD([1, 2])) +} + def testMatchOneElementArrayScope() { head = 100 tail = 200 @@ -102,16 +113,16 @@ def testMatchOneElementArrayScope() { def testMatchOneElementArrayDefinedVariableScope() { head = 100 tail = 200 - last = 300 + rest = 300 result = match [1] { case [head :: tail]: fail("Multi-array") case []: fail("Empty array") - case last: fail("Array should not be equal " + last) - case rest: assertEquals(1, rest[0]) + case rest: fail("Array should not be equal " + rest) + case [last]: assertEquals(1, last) } assertEquals(100, head) assertEquals(200, tail) - assertEquals(300, last) + assertEquals(300, rest) assertEquals(true, result) } @@ -121,7 +132,7 @@ def testMatchArrayScope() { result = match [1, 2, 3] { case [head :: tail]: assertEquals(1, head) case []: fail("Empty array") - case last: fail("One element") + case [last]: fail("One element") } assertEquals(100, head) assertEquals(200, tail) @@ -173,7 +184,31 @@ def testMatchTupleAny3() { assertEquals("_", result) } -def testScope() { - +def testDestructuringArray() { + parsedData = [ + ["Kyiv", 839, 3017000, "Ukraine", "...", "..."], + ["Shebekino", "N/A", "invalid"], + ["New York", 783.8, 18937000, "USA", "..."], + ["N/A"], + [] + ] + cities = [] + areas = [] + for row : parsedData { + match row { + // Match fully parsed data + case [name :: area :: population :: country]: { + cities ::= name + areas ::= area + } + // Match partially parsed data, which contains a city name and some other unknown values + case [name :: rest]: { + cities ::= name + } + // Match other invalid data + case arr: /* skip */ 0 + } + } + assertEquals(["Kyiv", "Shebekino", "New York"], cities) + assertEquals([839, 783.8], areas) } -