Skip to content

Commit 0d0361c

Browse files
authored
Merge pull request #80 from alanjds/wildcard-try1
Fixes `from time import *; print sleep`
2 parents 8539169 + 4c77928 commit 0d0361c

File tree

6 files changed

+72
-12
lines changed

6 files changed

+72
-12
lines changed

grumpy-runtime-src/runtime/module.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,58 @@ func importOne(f *Frame, name string) (*Object, *BaseException) {
178178
return o, nil
179179
}
180180

181+
// LoadMembers scans over all the members in module
182+
// and populates globals with them, taking __all__ into
183+
// account.
184+
func LoadMembers(f *Frame, module *Object) *BaseException {
185+
allAttr, raised := GetAttr(f, module, NewStr("__all__"), nil)
186+
if raised != nil && !raised.isInstance(AttributeErrorType) {
187+
return raised
188+
}
189+
f.RestoreExc(nil, nil)
190+
191+
if raised == nil {
192+
raised = loadMembersFromIterable(f, module, allAttr, nil)
193+
if raised != nil {
194+
return raised
195+
}
196+
return nil
197+
}
198+
199+
// Fall back on __dict__
200+
dictAttr := module.dict.ToObject()
201+
raised = loadMembersFromIterable(f, module, dictAttr, func(key *Object) bool {
202+
return strings.HasPrefix(toStrUnsafe(key).value, "_")
203+
})
204+
if raised != nil {
205+
return raised
206+
}
207+
return nil
208+
}
209+
210+
func loadMembersFromIterable(f *Frame, module, iterable *Object, filterF func(*Object) bool) *BaseException {
211+
globals := f.Globals()
212+
raised := seqForEach(f, iterable, func(memberName *Object) *BaseException {
213+
if !memberName.isInstance(StrType) {
214+
errorMessage := fmt.Sprintf("attribute name must be string, not '%v'", memberName.typ.Name())
215+
return f.RaiseType(AttributeErrorType, errorMessage)
216+
}
217+
member, raised := GetAttr(f, module, toStrUnsafe(memberName), nil)
218+
if raised != nil {
219+
return raised
220+
}
221+
if filterF != nil && filterF(memberName) {
222+
return nil
223+
}
224+
raised = globals.SetItem(f, memberName, member)
225+
if raised != nil {
226+
return raised
227+
}
228+
return nil
229+
})
230+
return raised
231+
}
232+
181233
// newModule creates a new Module object with the given fully qualified name
182234
// (e.g a.b.c) and its corresponding Python filename and package.
183235
func newModule(name, filename string) *Module {

grumpy-runtime-src/runtime/module_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ func TestImportModule(t *testing.T) {
184184
}
185185
}
186186
}
187-
188187
func TestModuleGetNameAndFilename(t *testing.T) {
189188
fun := wrapFuncForTest(func(f *Frame, m *Module) (*Tuple, *BaseException) {
190189
name, raised := m.GetName(f)

grumpy-tools-src/grumpy_tools/compiler/imputil.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class Import(object):
5151

5252
MODULE = "<BindType 'module'>"
5353
MEMBER = "<BindType 'member'>"
54+
STAR = "<BindType 'star'>"
5455

5556
def __init__(self, name, script=None, is_native=False):
5657
self.name = name
@@ -109,7 +110,14 @@ def visit_Import(self, node):
109110

110111
def visit_ImportFrom(self, node):
111112
if any(a.name == '*' for a in node.names):
112-
raise util.ImportError(node, 'wildcard member import is not implemented')
113+
if len(node.names) != 1:
114+
# TODO: Change to SyntaxError, as CPython does on "from foo import *, bar"
115+
raise util.ImportError(node, 'invalid syntax on wildcard import')
116+
117+
# Imported name is * (star). Will bind __all__ the module contents.
118+
imp = self._resolve_import(node, node.module)
119+
imp.add_binding(Import.STAR, '*', imp.name.count('.'))
120+
return [imp]
113121

114122
if not node.level and node.module == '__future__':
115123
return []

grumpy-tools-src/grumpy_tools/compiler/imputil_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,8 @@ def testImportFromAsMembers(self):
174174
imp.add_binding(imputil.Import.MEMBER, 'baz', 'bar')
175175
self._check_imports('from foo import bar as baz', [imp])
176176

177-
def testImportFromWildcardRaises(self):
178-
self.assertRaises(util.ImportError, self.importer.visit,
179-
pythonparser.parse('from foo import *').body[0])
177+
# def testImportFromWildcardRaises(self):
178+
# self._check_imports('from foo import *', [])
180179

181180
def testImportFromFuture(self):
182181
self._check_imports('from __future__ import print_function', [])

grumpy-tools-src/grumpy_tools/compiler/stmt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,8 @@ def _import_and_bind(self, imp):
629629
self.writer.write('{} = {}[{}]'.format(
630630
mod.name, mod_slice.expr, binding.value))
631631
self.block.bind_var(self.writer, binding.alias, mod.expr)
632+
elif binding.bind_type == imputil.Import.STAR:
633+
self.writer.write_checked_call1('πg.LoadMembers(πF, {}[0])', mod_slice.name)
632634
else:
633635
self.writer.write('{} = {}[{}]'.format(
634636
mod.name, mod_slice.expr, imp.name.count('.')))

grumpy-tools-src/grumpy_tools/compiler/stmt_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -333,19 +333,19 @@ def testImportNativeType(self):
333333
from "__go__/time" import Duration
334334
print Duration""")))
335335

336-
def testImportWildcardMemberRaises(self):
337-
regexp = 'wildcard member import is not implemented'
338-
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
339-
'from foo import *')
340-
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
341-
'from "__go__/foo" import *')
342-
343336
def testPrintStatement(self):
344337
self.assertEqual((0, 'abc 123\nfoo bar\n'), _GrumpRun(textwrap.dedent("""\
345338
print 'abc',
346339
print '123'
347340
print 'foo', 'bar'""")))
348341

342+
def testImportWildcard(self):
343+
result = _GrumpRun(textwrap.dedent("""\
344+
from time import *
345+
print sleep"""))
346+
self.assertEqual(0, result[0])
347+
self.assertIn('<function sleep at', result[1])
348+
349349
def testPrintFunction(self):
350350
want = "abc\n123\nabc 123\nabcx123\nabc 123 "
351351
self.assertEqual((0, want), _GrumpRun(textwrap.dedent("""\

0 commit comments

Comments
 (0)