source: trunk/CrypPlugins/CostFunction/RegEx.cs @ 2086

Last change on this file since 2086 was 2086, checked in by Sven Rech, 11 years ago

KeySearcher OpenCL implementation

File size: 24.5 KB
Line 
1/*                             
2   Copyright 2010 Sven Rech (svenrech at googlemail dot com), Uni Duisburg-Essen
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15*/
16
17using System;
18using System.Collections.Generic;
19using System.Linq;
20using System.Text;
21using System.Diagnostics;
22
23namespace Cryptool.Plugins.CostFunction
24{
25    /// <summary>
26    /// An implementation of fast regex matching algorithm.
27    /// Note: Only the matching is fast. Compiling the regex isn't!
28    ///
29    /// The matching algorithm works on byte arrays and not on string.
30    /// This is because the cost function also works on byte[]. So we save a conversion.
31    /// </summary>
32    internal class RegEx
33    {
34        private const int NOTRANSITION = int.MaxValue;
35        private int[][] transitionMatrix = null;
36        private int startIndex;
37
38        public string Regex
39        {
40            get;
41            private set;
42        }
43
44        public RegEx(string regex, bool caseSensitiv)
45        {
46            if (caseSensitiv)
47                regex = regex.ToLower();
48
49            //convert regex to NFA:
50            int index = -1;
51            NFA nfa = RegexToNFA(regex, ref index);
52            if (index < regex.Length)
53            {               
54                throw new ParseException("Error occurred while parsing the regular expression!");
55            }
56
57            if (nfa != null)
58            {
59                //make epsilon transitions superfluous:
60                nfa.RemoveEpsilonTransitions();
61
62                //convert NFA to DFA:
63                transitionMatrix = nfa.GetDFATransitionMatrix(out startIndex);
64
65                //manipulate transition matrix to make it case sensitiv:
66                if (caseSensitiv)
67                {
68                    for (int x = 0; x < transitionMatrix.Length; x++)
69                    {
70                        for (int y = 0; y <= ('Z' - 'A'); y++)
71                            transitionMatrix[x][(byte)(y + 'A')] = transitionMatrix[x][(byte)(y + 'a')];
72                        transitionMatrix[x][(byte)('Ä')] = transitionMatrix[x][(byte)('ä')];
73                        transitionMatrix[x][(byte)('Ö')] = transitionMatrix[x][(byte)('ö')];
74                        transitionMatrix[x][(byte)('Ü')] = transitionMatrix[x][(byte)('ü')];
75                    }
76                }
77            }
78        }
79
80        /// <summary>
81        /// This method modifies the given OpenCL code, so that the returning code includes
82        /// the costfunction RegEx calculations.
83        /// </summary>
84        /// <param name="code"></param>
85        /// <param name="bytesToUse"></param>
86        /// <returns></returns>
87        public string ModifyOpenCLCode(string code, int bytesToUse)
88        {
89            if (transitionMatrix == null)
90            {
91                //return false:
92                return code.Replace("$$COSTFUNCTIONDECLARATIONS$$", "").Replace("$$COSTFUNCTIONINITIALIZE$$", "results[x] = -1.0f; return;")
93                    .Replace("$$COSTFUNCTIONCALCULATE$$", "").Replace("$$COSTFUNCTIONRESULTCALCULATION$$", "result = -1.0f;");
94            }
95
96            //declaration code:
97            string declaration = string.Format("__constant int transitionMatrix[{0}] = {{ \n", transitionMatrix.Length * 256);
98            foreach (var row in transitionMatrix)
99            {
100                foreach (var i in row)
101                {
102                    declaration += i + ", ";
103                }
104            }
105            declaration = declaration.Substring(0, declaration.Length - 2);
106            declaration += " }; \n";
107            declaration += string.Format("__constant int NOTRANSITION = {0}; \n", int.MaxValue);
108
109            code = code.Replace("$$COSTFUNCTIONDECLARATIONS$$", declaration);
110
111            //initialization code:
112            code = code.Replace("$$COSTFUNCTIONINITIALIZE$$", string.Format("int state = {0}; \n", startIndex));
113
114            //calculation code:
115            code = code.Replace("$$COSTFUNCTIONCALCULATE$$", "int absState = ((state >= 0) ? state : ~state); \n"
116                + "state = transitionMatrix[absState*256+c]; \n"
117                + "if (state == NOTRANSITION) { results[x] = -1.0f; return;} \n");
118
119            //result calculation code:
120            code = code.Replace("$$COSTFUNCTIONRESULTCALCULATION$$", "if (state < 0) result = 1.0f; else result = -1.0f;");
121
122            return code;
123        }
124
125        /// <summary>
126        /// Matches input with the given regex.
127        /// </summary>
128        /// <param name="input"></param>
129        /// <returns></returns>
130        public bool Matches(byte[] input)
131        {
132            if (transitionMatrix == null)
133                return false;
134
135            int state = startIndex;
136
137            foreach (byte i in input)
138            {
139                int absState = state >= 0 ? state : ~state;
140                state = transitionMatrix[absState][i];
141                if (state == NOTRANSITION)
142                    return false;
143            }
144
145            return (state < 0);
146        }
147
148        #region NFA
149
150        /// <summary>
151        /// Implements a nondeterministic finite automaton and some useful functions.
152        /// </summary>
153        private class NFA
154        {
155            /// <summary>
156            /// Represents the connection to the node Node which is triggered by the TransitionBytes.
157            /// Epsilon can also be a transition reason.
158            /// </summary>
159            private struct Connection
160            {
161                public List<Connection> Node;
162                public byte[] TransitionBytes;
163                public bool Epsilon;               
164            }
165
166            //This list stores for every node a list of his connections to other nodes:
167            private List<List<Connection>> nodeConnections = new List<List<Connection>>();
168
169            //start and end nodes:
170            private int startIndex;
171            private List<int> endIndices = new List<int>();
172
173            /// <summary>
174            /// Concatenates two NFAs
175            /// </summary>
176            /// <param name="nfa1">The first NFA</param>
177            /// <param name="nfa2">The second NFA</param>
178            /// <returns></returns>
179            public static NFA ConcatNFAs(NFA nfa1, NFA nfa2)
180            {
181                if (nfa1 == null && nfa2 == null)
182                    return null;
183                else if (nfa1 == null)
184                    return nfa2;
185                else if (nfa2 == null)
186                    return nfa1;
187
188                NFA result = new NFA();
189                result.nodeConnections = new List<List<Connection>>();
190                result.nodeConnections.AddRange(nfa1.nodeConnections);
191                result.nodeConnections.AddRange(nfa2.nodeConnections);
192
193                //bridge nfa1 and nfa2 by an epsilon connection:
194                Connection bridge = new Connection();
195                bridge.Epsilon = true;
196                bridge.Node = nfa2.nodeConnections[nfa2.startIndex];
197                result.nodeConnections[nfa1.endIndices[0]].Add(bridge);
198
199                result.startIndex = nfa1.startIndex;
200                result.endIndices.Add(nfa1.nodeConnections.Count + nfa2.endIndices[0]);
201
202                return result;
203            }
204
205            public static NFA AlternateNFAs(NFA nfa1, NFA nfa2)
206            {
207                if (nfa1 == null && nfa2 == null)
208                    return null;
209                else if (nfa1 == null)
210                    return nfa2;
211                else if (nfa2 == null)
212                    return nfa1;
213
214                NFA result = new NFA();
215                result.nodeConnections = new List<List<Connection>>();
216                result.nodeConnections.AddRange(nfa1.nodeConnections);
217                result.nodeConnections.AddRange(nfa2.nodeConnections);
218
219                //create new start node:
220                Connection newStartConnection1 = new Connection();
221                newStartConnection1.Epsilon = true;
222                newStartConnection1.Node = result.nodeConnections[nfa1.startIndex];
223                Connection newStartConnection2 = new Connection();
224                newStartConnection2.Epsilon = true;
225                newStartConnection2.Node = result.nodeConnections[nfa1.nodeConnections.Count + nfa2.startIndex];
226                List<Connection> newStart = new List<Connection>() { newStartConnection1, newStartConnection2 };
227                result.nodeConnections.Add(newStart);
228                result.startIndex = result.nodeConnections.Count - 1;
229
230                //create new end node:
231                List<Connection> newEnd = new List<Connection>();
232                result.nodeConnections.Add(newEnd);
233                result.endIndices.Add(result.nodeConnections.Count - 1);
234                Connection newEndConnection1 = new Connection();
235                newEndConnection1.Epsilon = true;
236                newEndConnection1.Node = result.nodeConnections[result.endIndices[0]];
237                result.nodeConnections[nfa1.endIndices[0]].Add(newEndConnection1);
238                Connection newEndConnection2 = new Connection();
239                newEndConnection2.Epsilon = true;
240                newEndConnection2.Node = result.nodeConnections[result.endIndices[0]];
241                result.nodeConnections[nfa1.nodeConnections.Count + nfa2.endIndices[0]].Add(newEndConnection1);
242
243                return result;
244            }
245
246            public void KleeneStar()
247            {
248                //create new end connection:
249                List<Connection> newEnd = new List<Connection>();
250                nodeConnections.Add(newEnd);
251
252                //create new start connection:
253                Connection newStartConnection1 = new Connection();
254                newStartConnection1.Epsilon = true;
255                newStartConnection1.Node = nodeConnections[startIndex];
256                Connection newStartConnection2 = new Connection();
257                newStartConnection2.Epsilon = true;
258                newStartConnection2.Node = newEnd;
259                List<Connection> newStart = new List<Connection>() { newStartConnection1, newStartConnection2 };
260                nodeConnections.Add(newStart);
261                startIndex = nodeConnections.Count - 1;
262
263                //connects old end with start:
264                Connection oldEndWithStartConnection = new Connection();
265                oldEndWithStartConnection.Epsilon = true;
266                oldEndWithStartConnection.Node = newStart;
267                nodeConnections[endIndices[0]].Add(oldEndWithStartConnection);
268
269                endIndices[0] = nodeConnections.IndexOf(newEnd);
270            }
271
272            public static NFA ByteTransitionNFA(byte[] transitionBytes)
273            {
274                NFA result = new NFA();
275
276                List<Connection> newEnd = new List<Connection>();
277                Connection transition = new Connection();
278                transition.Epsilon = false;
279                transition.Node = newEnd;
280                transition.TransitionBytes = transitionBytes;
281                List<Connection> newStart = new List<Connection> { transition };
282
283                result.nodeConnections.Add(newStart);
284                result.nodeConnections.Add(newEnd);
285                result.startIndex = 0;
286                result.endIndices.Add(1);
287
288                return result;
289            }
290
291            public override string ToString()
292            {
293                int i = 0;
294                Dictionary<List<Connection>, int> nodeToID = new Dictionary<List<Connection>, int>();
295                string res = "";
296
297                foreach (List<Connection> node in nodeConnections)
298                {
299                    nodeToID.Add(node, i++);
300                }
301
302                res += "Start: " + nodeToID[nodeConnections[startIndex]] + "\n";
303                res += "End: (";
304                foreach (int ei in endIndices)
305                    res += nodeToID[nodeConnections[ei]] + ", ";
306                res += ")\n";
307
308                foreach (List<Connection> node in nodeConnections)
309                {
310                    res += nodeToID[node] + ": ";
311                    foreach (Connection c in node)
312                        res += " -" + ListArray(c.TransitionBytes) + "-" + (c.Epsilon ? "e" : "") + "-> " + nodeToID[c.Node] + " ";
313                    res += "\n";
314                }
315
316                return res;
317            }
318
319            private string ListArray(byte[] p)
320            {
321                if (p == null)
322                    return "";
323
324                string res = "[";
325                foreach (byte b in p)
326                    res += Convert.ToChar(b);
327                return res + "]";
328            }
329
330            public void RemoveEpsilonTransitions()
331            {
332                bool update = true;
333                               
334                while (update)
335                {
336                    update = false;
337                    //for all epsilon connections, we create a new connection that "jumps" over the intermediate node:
338                    foreach (var node in nodeConnections)
339                    {
340                        for (int i = 0; i < node.Count; i++)
341                        {
342                            var connection = node[i];
343                            if (connection.Epsilon)
344                            {
345                                foreach (var c in connection.Node)
346                                {
347                                    if (!node.Contains(c))
348                                    {
349                                        node.Add(c);
350                                        update = true;
351                                    }
352                                }
353
354                                //if we have an epsilon connection to an end node:
355                                if (endIndices.Contains(nodeConnections.IndexOf(connection.Node)))
356                                {
357                                    int index = nodeConnections.IndexOf(node);
358                                    if (!endIndices.Contains(index))
359                                        endIndices.Add(index);
360                                }
361                            }
362                        }
363                    }
364                }
365
366                //remove all epsilon connections:
367                foreach (var node in nodeConnections)
368                {
369                    for (int i = node.Count - 1; i >= 0; i--)
370                    {
371                        if (node[i].Epsilon)
372                        {
373                            if (node[i].TransitionBytes == null)
374                                node.RemoveAt(i);
375                            else
376                            {
377                                Connection c = node[i];
378                                c.Epsilon = false;
379                            }
380                        }
381                    }
382                }
383
384            }
385
386
387            public int[][] GetDFATransitionMatrix(out int start)
388            {
389                List<int[]> transitions = new List<int[]>();
390                Dictionary<HashSet<List<Connection>>, int> states = new Dictionary<HashSet<List<Connection>>, int>();   //mapping of node powersets to state id
391                int idCounter = 0;
392
393                start = idCounter++;
394                var startState = new HashSet<List<Connection>>(new List<Connection>[] { nodeConnections[startIndex] });
395                if (endIndices.Contains(startIndex))
396                    start = ~start;
397                states.Add(startState, start);
398
399                int[] newTransitions = new int[256];               
400
401                Queue<HashSet<List<Connection>>> newStates = new Queue<HashSet<List<Connection>>>();
402                newStates.Enqueue(startState);
403                               
404                while (newStates.Count != 0)
405                {
406                    var state = newStates.Dequeue();
407
408                    //create a new row in transition matrix for this state:
409                    newTransitions = new int[256];
410                    for (int i = 0; i < 256; i++)
411                        newTransitions[i] = NOTRANSITION;
412                    transitions.Add(newTransitions);
413
414                    //we have to check the transistions for every possible transition byte.
415                    for (int b = byte.MinValue; b <= byte.MaxValue; b++)
416                    {
417                        //check which nodes we would reach by transition over b. Make a state out of these node powerset.
418                        var transitionState = new HashSet<List<Connection>>();
419                        foreach (var node in state)
420                        {
421                            foreach (var connection in node)
422                                if (connection.TransitionBytes.Contains((byte)b))
423                                    transitionState.Add(connection.Node);
424                        }
425                        if (transitionState.Count == 0)
426                            continue;
427
428                        int transitionStateIndex = NOTRANSITION;
429
430                        //check if this state already exists:
431                        foreach (var s in states.Keys)
432                        {
433                            if (s.SetEquals(transitionState))
434                            {
435                                transitionStateIndex = states[s];
436                                break;
437                            }
438                        }
439
440                        //if not, we have to create it:
441                        if (transitionStateIndex == NOTRANSITION)
442                        {
443                            transitionStateIndex = idCounter++;
444
445                            foreach (var node in transitionState)   //check if this is an end state
446                            {
447                                if (endIndices.Contains(nodeConnections.IndexOf(node)))
448                                {
449                                    transitionStateIndex = ~transitionStateIndex;
450                                    break;
451                                }
452                            }
453
454                            states.Add(transitionState, transitionStateIndex);
455                            newStates.Enqueue(transitionState);                           
456                        }
457
458                        //put transition into table:
459                        int absState = states[state] >= 0 ? states[state] : ~states[state];
460                        transitions[absState][b] = transitionStateIndex;
461                    }
462
463                   
464                }
465
466                return transitions.ToArray();
467            }
468        }
469
470        #endregion
471       
472        /// <summary>
473        /// Parses the regex and converts it into an NFA.
474        /// </summary>
475        /// <param name="regex"></param>
476        private NFA RegexToNFA(string regex, ref int index)
477        {
478            bool bracket = false;
479            if (index >= 0 && index < regex.Length)
480                bracket = (regex[index] == '(');
481
482            if (bracket || (index < 0))
483                index++;
484
485            NFA result = null;
486
487            while (index < regex.Length && regex[index] != ')')
488            {
489                NFA newNFA = null;
490                int combineMode = -1;   //0 = concat, 1 = alternate
491                switch (regex[index])
492                {
493                    case '(':
494                        newNFA = RegexToNFA(regex, ref index);
495                        combineMode = 0;
496                        break;
497                    case '|':
498                        index++;
499                        newNFA = RegexToNFA(regex, ref index);
500                        combineMode = 1;
501                        break;
502                    case '*':
503                        index++;
504                        combineMode = -1;
505                        break;
506                    case '[':
507                        byte[] transitionBytes = GetTransitionSet(regex, ref index);
508                        newNFA = NFA.ByteTransitionNFA(transitionBytes);
509                        combineMode = 0;
510                        break;
511                    case '.':
512                        transitionBytes = new byte[256];
513                        for (int b = byte.MinValue; b <= byte.MaxValue; b++)
514                            transitionBytes[b] = (byte)b;
515                        newNFA = NFA.ByteTransitionNFA(transitionBytes);
516                        combineMode = 0;
517                        index++;
518                        break;
519                    case ']':
520                        throw new ParseException("Error at position " + index + "! No ']' expected. Escape it!");
521                    case '\\':
522                        index++;
523                        transitionBytes = new byte[] { Convert.ToByte(regex[index++]) };
524                        newNFA = NFA.ByteTransitionNFA(transitionBytes);
525                        combineMode = 0;
526                        break;
527                    default:
528                        transitionBytes = new byte[] { Convert.ToByte(regex[index++]) };
529                        newNFA = NFA.ByteTransitionNFA(transitionBytes);
530                        combineMode = 0;
531                        break;
532                }
533
534                if (index < regex.Length && regex[index] == '*')
535                {
536                    newNFA.KleeneStar();
537                    index++;
538                }
539
540                if (combineMode == 0)
541                    result = NFA.ConcatNFAs(result, newNFA);
542                else if (combineMode == 1)
543                    result = NFA.AlternateNFAs(result, newNFA);
544            }
545
546
547            if (bracket && (index < regex.Length) && regex[index] == ')')
548                index++;
549            else if (bracket && ((index >= regex.Length) || regex[index] != ')'))
550                throw new ParseException("Error at position " + index + "! Closing bracket expected!");
551
552            return result;
553        }
554
555        private byte[] GetTransitionSet(string regex, ref int index)
556        {
557            List<byte> set = new List<byte>();
558            index++;
559            bool invert = false;
560            if (regex[index] == '^')
561            {
562                index++;
563                invert = true;
564            }
565
566            while (index < regex.Length && regex[index] != ']')
567            {
568                if (index + 2 < regex.Length && regex[index + 1] == '-' && regex[index + 1] != ']')     //group
569                {
570                    for (byte b = Convert.ToByte(regex[index]); b <= Convert.ToByte(regex[index+2]); b++)
571                        set.Add(b);
572                    index += 3;
573                }
574                else if (index + 1 < regex.Length && regex[index + 1] == '-')   //misuse of '-' group sign
575                {
576                    throw new ParseException("Error at position " + (index + 1) + "! Misuse of '-' in group!");
577                }
578                else    //normal case
579                {
580                    //TODO: Maybe we should check here, if character is allowed at this position.
581                    set.Add(Convert.ToByte(regex[index]));
582                    index++;
583                }
584            }
585
586            if (index >= regex.Length)
587                throw new ParseException("Error at the end of the expression! ']' expected!");
588
589            byte[] result;
590            if (!invert)
591                result = set.ToArray();
592            else
593            {
594                List<byte> set2 = new List<byte>();
595                for (int b = byte.MinValue; b <= byte.MaxValue; b++)
596                    if (!set.Contains((byte)b))
597                        set2.Add((byte)b);
598                result = set2.ToArray();
599            }
600
601            index++;
602
603            return result;
604        }
605
606    }
607
608    public class ParseException : Exception
609    {
610        public ParseException(string message) : base(message)
611        {
612        }
613    }
614
615   
616}
Note: See TracBrowser for help on using the repository browser.