from types import ClassType, BuiltinFunctionType
from keyword import iskeyword
import threading
from Common import *
from Servlet import Servlet
# Set to True to print module import diagnostics from the servlet factory's
# import helpers (importAsPackage / _importModuleFromDirectory).
debug = False
class ServletFactory(Object):
"""Servlet factory template.
ServletFactory is an abstract class that defines the protocol for
all servlet factories.
Servlet factories are used by the Application to create servlets
for transactions.
A factory must inherit from this class and override uniqueness(),
extensions() and either loadClass() or servletForTransaction().
Do not invoke the base class methods as they all raise AbstractErrors.
Each method is documented below.
"""
def __init__(self, application):
"""Create servlet factory.
Stores a reference to the application in self._app, because
subclasses may or may not need to talk back to the application
to do their work.
"""
Object.__init__(self)
self._app = application
self._imp = self._app._imp
self._cacheClasses = self._app.setting("CacheServletClasses", True)
self._cacheInstances = self._app.setting("CacheServletInstances", True)
self._classCache = {}
self._servletPool = {}
self._threadsafeServletCache = {}
self._importLock = threading.RLock()
def name(self):
"""Return the name of the factory.
This is a convenience for the class name.
"""
return self.__class__.__name__
def uniqueness(self):
"""Return uniqueness type.
Returns a string to indicate the uniqueness of the ServletFactory's
servlets. The Application needs to know if the servlets are unique
per file, per extension or per application.
Return values are 'file', 'extension' and 'application'.
NOTE: Application only supports 'file' uniqueness at this point in time.
"""
raise AbstractError, self.__class__
def extensions(self):
"""Return a list of extensions that match this handler.
Extensions should include the dot. An empty string indicates a file
with no extension and is a valid value. The extension '.*' is a special
case that is looked for a URL's extension doesn't match anything.
"""
raise AbstractError, self.__class__
def importAsPackage(self, transaction, serverSidePathToImport):
"""Import requested module.
Imports the module at the given path in the proper package/subpackage
for the current request. For example, if the transaction has the URL
http://localhost/WebKit.cgi/MyContextDirectory/MySubdirectory/MyPage
and path = 'some/random/path/MyModule.py' and the context is configured
to have the name 'MyContext' then this function imports the module at
that path as MyContext.MySubdirectory.MyModule . Note that the context
name may differ from the name of the directory containing the context,
even though they are usually the same by convention.
Note that the module imported may have a different name from the
servlet name specified in the URL. This is used in PSP.
"""
request = transaction.request()
path = request.serverSidePath()
contextPath = request.serverSideContextPath()
fullname = request.contextName()
if not fullname or not path.startswith(contextPath):
remainder = serverSidePathToImport
fullmodname = remainder.replace(
'\\', '_').replace('/', '_').replace('.', '_')
if debug:
print __file__, ", fullmodname =", fullmodname
modname = os.path.splitext(os.path.basename(
serverSidePathToImport))[0]
fp, pathname, stuff = self._imp.find_module(modname,
[os.path.dirname(serverSidePathToImport)])
module = self._imp.load_module(fullmodname, fp, pathname, stuff)
module.__donotreload__ = True
return module
directory, contextDirName = os.path.split(contextPath)
self._importModuleFromDirectory(fullname, contextDirName,
directory, isPackageDir=True)
directory = contextPath
remainder = path[len(contextPath)+1:].replace('\\', '/')
remainder = remainder.split('/')
for name in remainder[:-1]:
fullname = '%s.%s' % (fullname, name)
self._importModuleFromDirectory(fullname, name,
directory, isPackageDir=True)
directory = os.path.join(directory, name)
moduleFileName = os.path.basename(serverSidePathToImport)
moduleDir = os.path.dirname(serverSidePathToImport)
name = os.path.splitext(moduleFileName)[0]
fullname = '%s.%s' % (fullname, name)
module = self._importModuleFromDirectory(fullname, name,
moduleDir, forceReload=True)
return module
def _importModuleFromDirectory(self, fullModuleName, moduleName,
directory, isPackageDir=False, forceReload=False):
"""Imports the given module from the given directory.
fullModuleName should be the full dotted name that will be given
to the module within Python. moduleName should be the name of the
module in the filesystem, which may be different from the name
given in fullModuleName. Returns the module object. If forceReload is
True then this reloads the module even if it has already been imported.
If isPackageDir is True, then this function creates an empty
__init__.py if that file doesn't already exist.
"""
if debug:
print __file__, fullModuleName, moduleName, directory
if not forceReload:
module = sys.modules.get(fullModuleName, None)
if module is not None:
return module
fp = None
if isPackageDir:
packageDir = os.path.join(directory, moduleName)
initPy = os.path.join(packageDir, '__init__.py')
for ext in ('', 'c', 'o'):
if os.path.exists(initPy + ext):
break
else:
file = open(initPy, 'w')
file.write('#')
file.close()
fp, pathname, stuff = self._imp.find_module(moduleName, [directory])
module = self._imp.load_module(fullModuleName, fp, pathname, stuff)
module.__donotreload__ = True
return module
def loadClass(self, transaction, path):
"""Load the appropriate class.
Given a transaction and a path, load the class for creating these
servlets. Caching, pooling, and threadsafeness are all handled by
servletForTransaction. This method is not expected to be threadsafe.
"""
raise AbstractError, self.__class__
def servletForTransaction(self, transaction):
"""Return a new servlet that will handle the transaction.
This method handles caching, and will call loadClass(trans, filepath)
if no cache is found. Caching is generally controlled by servlets
with the canBeReused() and canBeThreaded() methods.
"""
request = transaction.request()
path = request.serverSidePath()
mtime = os.path.getmtime(path)
if not self._classCache.has_key(path) or \
mtime != self._classCache[path]['mtime']:
self._importLock.acquire()
try:
if not self._classCache.has_key(path) or \
mtime != self._classCache[path]['mtime']:
theClass = self.loadClass(transaction, path)
if self._cacheClasses:
self._classCache[path] = {
'mtime': mtime, 'class': theClass}
else:
theClass = self._classCache[path]['class']
finally:
self._importLock.release()
else:
theClass = self._classCache[path]['class']
if self._threadsafeServletCache.has_key(path):
servlet = self._threadsafeServletCache[path]
if servlet.__class__ is theClass:
return servlet
else:
while 1:
try:
servlet = self._servletPool[path].pop()
except (KeyError, IndexError):
break
else:
if servlet.__class__ is theClass:
servlet.open()
return servlet
self._importLock.acquire()
try:
mtime = os.path.getmtime(path)
if not self._classCache.has_key(path):
self._classCache[path] = {
'mtime': mtime,
'class': self.loadClass(transaction, path)}
elif mtime > self._classCache[path]['mtime']:
self._classCache[path]['mtime'] = mtime
self._classCache[path]['class'] = self.loadClass(
transaction, path)
theClass = self._classCache[path]['class']
if not self._cacheClasses:
del self._classCache[path]
finally:
self._importLock.release()
servlet = theClass()
servlet.setFactory(self)
if servlet.canBeReused():
if servlet.canBeThreaded():
self._threadsafeServletCache[path] = servlet
else:
self._servletPool[path] = []
servlet.open()
return servlet
def returnServlet(self, servlet):
"""Return servlet to the pool.
Called by Servlet.close(), which returns the servlet
to the servlet pool if necessary.
"""
if servlet.canBeReused() and not servlet.canBeThreaded() \
and self._cacheInstances:
path = servlet.serverSidePath()
self._servletPool[path].append(servlet)
def flushCache(self):
"""Flush the servlet cache and start fresh.
Servlets that are currently in the wild may find their way back
into the cache (this may be a problem).
"""
self._importLock.acquire()
self._classCache = {}
for key in self._servletPool.keys():
self._servletPool[key] = []
self._threadsafeServletCache = {}
self._importLock.release()
class PythonServletFactory(ServletFactory):
"""The factory for Python servlets.
This is the factory for ordinary Python servlets whose extensions
are empty or .py. The servlets are unique per file since the file
itself defines the servlet.
"""
def uniqueness(self):
return 'file'
def extensions(self):
return ['.py', '.pyc', '.pyo']
def loadClass(self, transaction, path):
module = self.importAsPackage(transaction, path)
name = os.path.splitext(os.path.split(path)[1])[0]
if not hasattr(module, name):
name = name.replace('-', '_').replace(' ', '_')
if iskeyword(name):
name += '_'
if not hasattr(module, name):
raise ValueError, \
'Cannot find expected servlet class %r in %r.' \
% (name, path)
theClass = getattr(module, name)
if type(type) is BuiltinFunctionType:
assert type(theClass) is ClassType
else:
assert type(theClass) is ClassType \
or isinstance(theClass, type)
assert issubclass(theClass, Servlet)
return theClass