diff --git a/examples/basics/pattern_matching.own b/examples/basics/pattern_matching.own index 4805fa1..7a0ff50 100644 --- a/examples/basics/pattern_matching.own +++ b/examples/basics/pattern_matching.own @@ -38,3 +38,28 @@ def printArrayRecursive(arr) = match arr { case []: "[]" case last: "[" + last + ", []]" } + +println "\nPattern matching on arrays by value" +def tupleMatch(x) = match x { + case (0, 0): "00" + case (1, 0): "10" + case (0, 1): "01" + case (1, 1): "11" + case (2, _): "2?" + case _: "unknown" +} +println tupleMatch([0, 1]) +println tupleMatch([1, 1]) +println tupleMatch([2, 1]) +println tupleMatch([3, 9]) + + +println "\nFizzBuzz with pattern matching" +for i = 1, i <= 100, i++ { + println match [i % 3 == 0, i % 5 == 0] { + case (true, false): "Fizz" + case (false, true): "Buzz" + case (true, true): "FizzBuzz" + case _: i + } +} \ No newline at end of file diff --git a/src/com/annimon/ownlang/parser/Parser.java b/src/com/annimon/ownlang/parser/Parser.java index 969dd64..8e85633 100644 --- a/src/com/annimon/ownlang/parser/Parser.java +++ b/src/com/annimon/ownlang/parser/Parser.java @@ -368,6 +368,19 @@ public final class Parser { match(TokenType.COLONCOLON); } pattern = listPattern; + } else if (match(TokenType.LPAREN)) { + // case (1, 2): + final MatchExpression.TuplePattern tuplePattern = new MatchExpression.TuplePattern(); + while (!match(TokenType.RPAREN)) { + if ("_".equals(get(0).getText())) { + tuplePattern.addAny(); + consume(TokenType.WORD); + } else { + tuplePattern.add(expression()); + } + match(TokenType.COMMA); + } + pattern = tuplePattern; } if (pattern == null) { diff --git a/src/com/annimon/ownlang/parser/ast/MatchExpression.java b/src/com/annimon/ownlang/parser/ast/MatchExpression.java index e21a8aa..3eb477c 100644 --- a/src/com/annimon/ownlang/parser/ast/MatchExpression.java +++ b/src/com/annimon/ownlang/parser/ast/MatchExpression.java @@ -7,6 +7,7 @@ import com.annimon.ownlang.lib.Types; import com.annimon.ownlang.lib.Value; import com.annimon.ownlang.lib.Variables; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; /** @@ -14,7 +15,7 @@ import java.util.List; * @author aNNiMON */ public final class MatchExpression implements Expression, Statement { - + public final Expression expression; public final List patterns; @@ -22,12 +23,12 @@ public final class MatchExpression implements Expression, Statement { this.expression = expression; this.patterns = patterns; } - + @Override public void execute() { eval(); } - + @Override public Value eval() { final Value value = expression.eval(); @@ -41,7 +42,7 @@ public final class MatchExpression implements Expression, Statement { if (p instanceof VariablePattern) { final VariablePattern pattern = (VariablePattern) p; if (pattern.variable.equals("_")) return evalResult(p.result); - + if (Variables.isExists(pattern.variable)) { if (match(value, Variables.get(pattern.variable)) && optMatches(p)) { return evalResult(p.result); @@ -49,7 +50,7 @@ public final class MatchExpression implements Expression, Statement { } else { Variables.define(pattern.variable, value); if (optMatches(p)) { - final Value result = evalResult(p.result);; + final Value result = evalResult(p.result); Variables.remove(pattern.variable); return result; } @@ -67,10 +68,29 @@ public final class MatchExpression implements Expression, Statement { return result; } } + if ((value.type() == Types.ARRAY) && (p instanceof TuplePattern)) { + final TuplePattern pattern = (TuplePattern) p; + if (matchTuplePattern((ArrayValue) value, pattern) && optMatches(p)) { + return evalResult(p.result); + } + } } throw new PatternMatchingException("No pattern were matched"); } - + + private boolean matchTuplePattern(ArrayValue array, TuplePattern p) { + if (p.values.size() != array.size()) return false; + + final int size = array.size(); + for (int i = 0; i < size; i++) { + final Expression expr = p.values.get(i); + if ( (expr != TuplePattern.ANY) && (expr.eval().compareTo(array.get(i)) != 0) ) { + return false; + } + } + return true; + } + private boolean matchListPattern(ArrayValue array, ListPattern p) { final List parts = p.parts; final int partsSize = parts.size(); @@ -90,7 +110,7 @@ public final class MatchExpression implements Expression, Statement { } Variables.remove(variable); return false; - + default: { // match arr { case [...]: .. } if (partsSize == arraySize) { // match [0, 1, 2] { case [a::b::c]: a=0, b=1, c=2 ... } @@ -102,7 +122,7 @@ public final class MatchExpression implements Expression, Statement { return false; } } - } + } private boolean matchListPatternEqualsSize(ListPattern p, List parts, int partsSize, ArrayValue array) { // Set variables @@ -119,7 +139,7 @@ public final class MatchExpression implements Expression, Statement { } return false; } - + private boolean matchListPatternWithTail(ListPattern p, List parts, int partsSize, ArrayValue array, int arraySize) { // Set element variables final int lastPart = partsSize - 1; @@ -143,17 +163,17 @@ public final class MatchExpression implements Expression, Statement { } return false; } - + private boolean match(Value value, Value constant) { if (value.type() != constant.type()) return false; return value.equals(constant); } - + private boolean optMatches(Pattern pattern) { if (pattern.optCondition == null) return true; return pattern.optCondition.eval() != NumberValue.ZERO; } - + private Value evalResult(Statement s) { try { s.execute(); @@ -162,7 +182,7 @@ public final class MatchExpression implements Expression, Statement { } return NumberValue.ZERO; } - + @Override public void accept(Visitor visitor) { visitor.visit(this); @@ -183,15 +203,15 @@ public final class MatchExpression implements Expression, Statement { sb.append("\n}"); return sb.toString(); } - + public static abstract class Pattern { public Statement result; public Expression optCondition; } - + public static class ConstantPattern extends Pattern { public Value constant; - + public ConstantPattern(Value pattern) { this.constant = pattern; } @@ -201,10 +221,10 @@ public final class MatchExpression implements Expression, Statement { return constant + ": " + result; } } - + public static class VariablePattern extends Pattern { public String variable; - + public VariablePattern(String pattern) { this.variable = pattern; } @@ -214,14 +234,14 @@ public final class MatchExpression implements Expression, Statement { return variable + ": " + result; } } - + public static class ListPattern extends Pattern { public List parts; - + public ListPattern() { this(new ArrayList()); } - + public ListPattern(List parts) { this.parts = parts; } @@ -232,7 +252,73 @@ public final class MatchExpression implements Expression, Statement { @Override public String toString() { - return parts + ": " + result; + final Iterator it = parts.iterator(); + if (it.hasNext()) { + final StringBuilder sb = new StringBuilder(); + sb.append("[").append(it.next()); + while (it.hasNext()) { + sb.append(" :: ").append(it.next()); + } + sb.append("]: ").append(result); + return sb.toString(); + } + return "[]: " + result; } } + + public static class TuplePattern extends Pattern { + public List values; + + public TuplePattern() { + this(new ArrayList()); + } + + public TuplePattern(List parts) { + this.values = parts; + } + + public void addAny() { + values.add(ANY); + } + + public void add(Expression value) { + values.add(value); + } + + @Override + public String toString() { + final Iterator it = values.iterator(); + if (it.hasNext()) { + final StringBuilder sb = new StringBuilder(); + sb.append("(").append(it.next()); + while (it.hasNext()) { + sb.append(", ").append(it.next()); + } + sb.append("): ").append(result); + return sb.toString(); + } + return "(): " + result; + } + + private static final Expression ANY = new Expression() { + @Override + public Value eval() { + return NumberValue.ONE; + } + + @Override + public void accept(Visitor visitor) { + } + + @Override + public R accept(ResultVisitor visitor, T input) { + return null; + } + + @Override + public String toString() { + return "_"; + } + }; + } }