tkbe

November 10, 2006

Python :: flatten

Filed under: python — tb @ 10:52 am

Although it can seem like a CS exercise, everyone will sooner or later have to write a "flatten" function: a function that takes a nested list and makes it one-dimensional:

PYTHON:
    >>> flatten([1, 2, [3, 4], 5])
    [1, 2, 3, 4, 5]

It should also handle other Python datatypes in a reasonable manner:

PYTHON:
    >>> flatten(['hello', 'world'])
    ['hello', 'world']
    >>> flatten(i**2 for i in range(10))
    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

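The trick to the implementation is to realize that everything that is iterable should be flattened... except strings and unicode. Strings have to be excluded because iterating over a string yields one-character strings, which are themselves iterable, so the recursion would never bottom out (and splitting 'hello' into letters is rarely what you want anyway). To test if something is iterable, call iter on it and be prepared to catch a TypeError: iterator over non-sequence. To test if something is a string or unicode, use isinstance(item, basestring). That results in something like this: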

PYTHON:
    def flatten(lst):
        def _flatten(lst, res):
            for item in lst:
                if isinstance(item, basestring):
                    res.append(item)
                else:
                    try:
                        _flatten(iter(item), res)
                    except TypeError: # iterator over nonsequence
                        res.append(item)
            return res
        return _flatten(lst, [])
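For readers on Python 3, where basestring no longer exists, the same list-building idea might be sketched as below. Using the tuple (str, bytes) in place of basestring is an assumption of this sketch, as is moving the recursive call out of the try block so that only the iter call itself is guarded.

```python
def flatten(lst):
    def _flatten(lst, res):
        for item in lst:
            # Assumption: (str, bytes) stands in for Python 2's basestring.
            if isinstance(item, (str, bytes)):
                res.append(item)
                continue
            try:
                iterator = iter(item)
            except TypeError:  # not iterable, so treat it as a leaf
                res.append(item)
            else:
                _flatten(iterator, res)
        return res
    return _flatten(lst, [])


print(flatten([1, 2, [3, 4], 5]))        # [1, 2, 3, 4, 5]
print(flatten(['hello', 'world']))       # ['hello', 'world']
print(flatten(i**2 for i in range(4)))   # [0, 1, 4, 9]
```

As in the listing above, the whole result accumulates in res while the recursion walks the structure.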

The list-building version generates the entire list in the res variable of the inner function as it recurses through the data structure. That could be a very long list, so perhaps a generator works better:

PYTHON:
    def flatten(lst):
        for item in lst:
            if isinstance(item, basestring):
                yield item
            else:
                try:
                    it = iter(item)
                except TypeError: # iterator over nonsequence
                    yield item
                else:
                    # calling flatten() only creates a generator;
                    # its items have to be re-yielded one by one
                    for sub in flatten(it):
                        yield sub

That looks pretty good :-)
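One point worth spelling out about the generator approach: calling a generator function only creates a generator object, so the recursive call has to be iterated and its items re-yielded for anything to come out. A minimal Python 3 sketch, assuming (str, bytes) in place of basestring and using yield from for the delegation, also shows the payoff of laziness, since it can consume an endless input:

```python
from itertools import count, islice


def flatten(lst):
    """Lazily yield the leaves of an arbitrarily nested iterable."""
    for item in lst:
        # Assumption: (str, bytes) stands in for Python 2's basestring.
        if isinstance(item, (str, bytes)):
            yield item
            continue
        try:
            iterator = iter(item)
        except TypeError:  # not iterable, so it is a leaf
            yield item
        else:
            # Delegate to the nested generator and re-yield its items.
            yield from flatten(iterator)


# An endless stream of pairs: [0, 'a'], [1, 'a'], [2, 'a'], ...
pairs = ([n, 'a'] for n in count())
print(list(islice(flatten(pairs), 5)))  # [0, 'a', 1, 'a', 2]
```

A list-building flatten would never return on that input; the generator only does as much work as the consumer asks for.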

So what was the reason I had to write flatten this time?...

PYTHON:
    >>> import html
    >>> from html import table, tr, td
    >>> td_num = html.mktag('td', width="20%", style="background:aqua")
    >>> table(tr(td_num(x), td(x**2)) for x in range(6))

Which gives

0 0
1 1
2 4
3 9
4 16
5 25

The __str__ method is only defined on the base class of all the tag classes, and it needed to flatten the tag's contents before rendering them...
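To make that concrete, here is a rough sketch of how a base class could use flatten in its __str__. Everything in it is hypothetical: the Tag class, its name attribute, and the way attributes are rendered are invented for illustration and are not the actual html module used above (mktag is left out entirely).

```python
def flatten(lst):
    for item in lst:
        if isinstance(item, (str, bytes)):
            yield item
            continue
        try:
            iterator = iter(item)
        except TypeError:  # not iterable, so it is a leaf
            yield item
        else:
            yield from flatten(iterator)


class Tag:
    """Hypothetical base class; the subclasses only set the tag name."""
    name = 'tag'

    def __init__(self, *content, **attrs):
        self.content = content
        self.attrs = attrs

    def __str__(self):
        attrs = ''.join(' %s="%s"' % (k, v) for k, v in self.attrs.items())
        # Content may mix numbers, strings, tags, lists and generators,
        # so flatten it before converting each leaf to text.
        body = ''.join(str(leaf) for leaf in flatten(self.content))
        return '<%s%s>%s</%s>' % (self.name, attrs, body, self.name)


class table(Tag):
    name = 'table'


class tr(Tag):
    name = 'tr'


class td(Tag):
    name = 'td'


print(table(tr(td(x), td(x**2)) for x in range(2)))
# <table><tr><td>0</td><td>0</td></tr><tr><td>1</td><td>1</td></tr></table>
```

Because the tag objects themselves are not iterable, flatten treats them as leaves, and calling str on each leaf recurses into the nested tags.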
