Commit a00f378d authored by Josh Roesslein's avatar Josh Roesslein
Browse files

2 space -> 4 space indents

parent 4b126837
......@@ -5,6 +5,7 @@ from getpass import getpass
import tweepy
class StreamWatcherListener(tweepy.StreamListener):
def on_status(self, status):
......@@ -56,6 +57,5 @@ while True:
# Shutdown connection
stream.disconnect()
print 'Bye!'
......@@ -5,17 +5,18 @@
import unittest
import random
from time import sleep
import os
from tweepy import *
"""Configurations"""
# Must supply twitter account credentials for tests
username = ''
password = ''
username = 'tweebly'
password = 'josh1987twitter'
"""Unit tests"""
# API tests
class TweepyAPITests(unittest.TestCase):
def setUp(self):
......@@ -43,7 +44,7 @@ class TweepyAPITests(unittest.TestCase):
def testupdateanddestroystatus(self):
# test update
text = 'testing %i' % random.randint(0,1000)
text = 'testing %i' % random.randint(0, 1000)
update = self.api.update_status(status=text)
self.assertEqual(update.text, text)
......@@ -100,7 +101,7 @@ class TweepyAPITests(unittest.TestCase):
self.assert_(isinstance(source, Friendship))
self.assert_(isinstance(target, Friendship))
# Authentication tests
class TweepyAuthTests(unittest.TestCase):
consumer_key = 'ZbzSsdQj7t68VYlqIFvdcA'
......@@ -120,17 +121,16 @@ class TweepyAuthTests(unittest.TestCase):
# build api object test using oauth
api = API(auth)
api.update_status('test %i' % random.randint(0,1000))
api.update_status('test %i' % random.randint(0, 1000))
def testbasicauth(self):
auth = BasicAuthHandler(username, password)
# test accessing twitter API
api = API(auth)
api.update_status('test %i' % random.randint(1,1000))
api.update_status('test %i' % random.randint(1, 1000))
# Cache tests
class TweepyCacheTests(unittest.TestCase):
timeout = 2.0
......@@ -139,11 +139,13 @@ class TweepyCacheTests(unittest.TestCase):
def _run_tests(self, do_cleanup=True):
# test store and get
self.cache.store('testkey', 'testvalue')
self.assertEqual(self.cache.get('testkey'), 'testvalue', 'Stored value does not match retrieved value')
self.assertEqual(self.cache.get('testkey'), 'testvalue',
'Stored value does not match retrieved value')
# test timeout
sleep(self.timeout)
self.assertEqual(self.cache.get('testkey'), None, 'Cache entry should have expired')
self.assertEqual(self.cache.get('testkey'), None,
'Cache entry should have expired')
# test cleanup
if do_cleanup:
......@@ -153,7 +155,7 @@ class TweepyCacheTests(unittest.TestCase):
self.assertEqual(self.cache.count(), 0, 'Cache cleanup failed')
# test count
for i in range(0,20):
for i in range(0, 20):
self.cache.store('testkey%i' % i, 'testvalue')
self.assertEqual(self.cache.count(), 20, 'Count is wrong')
......@@ -177,5 +179,6 @@ class TweepyCacheTests(unittest.TestCase):
self._run_tests(do_cleanup=False)
if __name__ == '__main__':
unittest.main()
......@@ -16,3 +16,4 @@ from . streaming import Stream, StreamListener
# Global, unauthenticated instance of API
api = API()
......@@ -10,8 +10,9 @@ from . error import TweepError
from . auth import BasicAuthHandler, OAuthHandler
from tweepy.parsers import *
"""Twitter API"""
class API(object):
"""Twitter API"""
def __init__(self, auth_handler=None, host='twitter.com', cache=None,
secure=False, api_root='', validate=True):
......@@ -106,9 +107,13 @@ class API(object):
raise TweepError('Authentication required')
try:
user = bind_api(path='/account/verify_credentials.json', parser=parse_user)(self)
user = bind_api(
path = '/account/verify_credentials.json',
parser = parse_user
)(self)
except TweepError, e:
raise TweepError('Failed to fetch username: %s' % e)
self._username = user.screen_name
return self.get_user(screen_name=self._username)
......@@ -215,7 +220,8 @@ class API(object):
return bind_api(
path = '/account/verify_credentials.json',
parser = parse_return_true,
require_auth = True)(self)
require_auth = True
)(self)
except TweepError:
return False
......@@ -347,6 +353,7 @@ class API(object):
)(self, **kargs)
except TweepError:
return False
return True
"""Get list of users that are blocked"""
......@@ -421,9 +428,7 @@ class API(object):
parser = parse_trend_results
)(self)
""" Pack image file into multipart-formdata request"""
def _pack_image(filename, max_size):
def _pack_image(filename, max_size):
"""Pack image from file into multipart-formdata post body"""
# image must be less than 700kb in size
try:
......@@ -437,7 +442,7 @@ def _pack_image(filename, max_size):
if file_type is None:
raise TweepError('Could not determine file type')
file_type = file_type[0]
if file_type != 'image/gif' and file_type != 'image/jpeg' and file_type != 'image/png':
if file_type not in ['image/gif', 'image/jpeg', 'image/png']:
raise TweepError('Invalid file type for image: %s' % file_type)
# build the mulitpart-formdata body
......
......@@ -3,18 +3,19 @@
# See LICENSE
from urllib2 import Request, urlopen
from urllib import quote
import base64
from . import oauth
from . error import TweepError
class AuthHandler(object):
def apply_auth(self, url, method, headers, parameters):
"""Apply authentication headers to request"""
raise NotImplemented
class BasicAuthHandler(AuthHandler):
def __init__(self, username, password):
......@@ -23,8 +24,9 @@ class BasicAuthHandler(AuthHandler):
def apply_auth(self, url, method, headers, parameters):
headers['Authorization'] = 'Basic %s' % self._b64up
"""OAuth authentication handler"""
class OAuthHandler(AuthHandler):
"""OAuth authentication handler"""
REQUEST_TOKEN_URL = 'http://twitter.com/oauth/request_token'
AUTHORIZATION_URL = 'http://twitter.com/oauth/authorize'
......@@ -38,19 +40,21 @@ class OAuthHandler(AuthHandler):
self.callback = callback
def apply_auth(self, url, method, headers, parameters):
request = oauth.OAuthRequest.from_consumer_and_token(self._consumer,
http_url=url, http_method=method, token=self.access_token, parameters=parameters)
request = oauth.OAuthRequest.from_consumer_and_token(
self._consumer, http_url=url, http_method=method,
token=self.access_token, parameters=parameters
)
request.sign_request(self._sigmethod, self._consumer, self.access_token)
headers.update(request.to_header())
def _get_request_token(self):
try:
request = oauth.OAuthRequest.from_consumer_and_token(self._consumer,
http_url = self.REQUEST_TOKEN_URL, callback=self.callback)
request = oauth.OAuthRequest.from_consumer_and_token(
self._consumer, http_url=self.REQUEST_TOKEN_URL, callback=self.callback
)
request.sign_request(self._sigmethod, self._consumer, None)
resp = urlopen(Request(self.REQUEST_TOKEN_URL, headers=request.to_header()))
return oauth.OAuthToken.from_string(resp.read())
except Exception, e:
raise TweepError(e)
......@@ -65,18 +69,25 @@ class OAuthHandler(AuthHandler):
# build auth request and return as url
request = oauth.OAuthRequest.from_token_and_callback(
token=self.request_token, http_url=self.AUTHORIZATION_URL)
return request.to_url()
token=self.request_token, http_url=self.AUTHORIZATION_URL
)
return request.to_url()
except Exception, e:
raise TweepError(e)
def get_access_token(self, verifier):
"""After user has authorized the request token, get access token with user supplied verifier."""
"""
After user has authorized the request token, get access token
with user supplied verifier.
"""
try:
# build request
request = oauth.OAuthRequest.from_consumer_and_token(self._consumer,
token=self.request_token, http_url=self.ACCESS_TOKEN_URL, verifier=str(verifier))
request = oauth.OAuthRequest.from_consumer_and_token(
self._consumer,
token=self.request_token, http_url=self.ACCESS_TOKEN_URL,
verifier=str(verifier)
)
request.sign_request(self._sigmethod, self._consumer, self.request_token)
# send request
......
......@@ -8,6 +8,7 @@ import urllib
from . parsers import parse_error
from . error import TweepError
def bind_api(path, parser, allowed_param=None, method='GET', require_auth=False,
timeout=None, host=None):
......@@ -71,8 +72,8 @@ def bind_api(path, parser, allowed_param=None, method='GET', require_auth=False,
# Check cache if caching enabled and method is GET
if api.cache and method == 'GET':
cache_result = api.cache.get(url, timeout)
if cache_result:
# if cache result found and not expired, return it
if cache_result:
# must restore api reference
if isinstance(cache_result, list):
for result in cache_result:
......@@ -99,7 +100,7 @@ def bind_api(path, parser, allowed_param=None, method='GET', require_auth=False,
try:
error_msg = parse_error(resp.read())
except Exception:
error_msg = "Unkown twitter error response received: status=%s" % resp.status
error_msg = "Twitter error response: status code = %s" % resp.status
raise TweepError(error_msg)
# Pass returned body into parser and return parser output
......@@ -125,3 +126,4 @@ def bind_api(path, parser, allowed_param=None, method='GET', require_auth=False,
return out
return _call
......@@ -12,13 +12,13 @@ import fcntl
import cPickle as pickle
from . import memcache
from . error import TweepError
"""Cache interface"""
class Cache(object):
"""Cache interface"""
def __init__(self, timeout=60):
"""Init the cache
"""Initialize the cache
timeout: number of seconds to keep a cached entry
"""
self.timeout = timeout
......@@ -49,8 +49,9 @@ class Cache(object):
"""Delete all cached entries"""
raise NotImplementedError
"""In-memory cache"""
class MemoryCache(Cache):
"""In-memory cache"""
def __init__(self, timeout=60):
Cache.__init__(self, timeout)
......@@ -100,7 +101,7 @@ class MemoryCache(Cache):
def cleanup(self):
with self.lock:
for k,v in self._entries.items():
for k, v in self._entries.items():
if self._is_expired(v, self.timeout):
del self._entries[k]
......@@ -108,8 +109,9 @@ class MemoryCache(Cache):
with self.lock:
self._entries.clear()
"""File-based cache"""
class FileCache(Cache):
"""File-based cache"""
# locks used to make cache thread-safe
cache_locks = {}
......@@ -194,22 +196,26 @@ class FileCache(Cache):
def count(self):
c = 0
for entry in os.listdir(self.cache_dir):
if entry.endswith('.lock'): continue
if entry.endswith('.lock'):
continue
c += 1
return c
def cleanup(self):
for entry in os.listdir(self.cache_dir):
if entry.endswith('.lock'): continue
if entry.endswith('.lock'):
continue
self._get(os.path.join(self.cache_dir, entry), None)
def flush(self):
for entry in os.listdir(self.cache_dir):
if entry.endswith('.lock'): continue
if entry.endswith('.lock'):
continue
self._delete_file(os.path.join(self.cache_dir, entry))
"""Memcache client"""
class MemCache(Cache):
"""Memcache client"""
def __init__(self, servers, timeout=60):
Cache.__init__(self, timeout)
......
......@@ -2,13 +2,12 @@
# Copyright 2009 Joshua Roesslein
# See LICENSE
"""
Tweepy exception
"""
class TweepError(Exception):
"""Tweepy exception"""
def __init__(self, reason):
self.reason = str(reason)
def __str__(self):
return self.reason
......@@ -4,13 +4,16 @@
from . error import TweepError
class Model(object):
def __getstate__(self):
# pickle
pickle = {}
for k,v in self.__dict__.items():
if k == '_api': continue # do not pickle the api reference
for k, v in self.__dict__.items():
if k == '_api':
# do not pickle the api reference
continue
pickle[k] = v
return pickle
......@@ -21,11 +24,13 @@ class Model(object):
if not hasattr(model, attr):
missing.append(attr)
if len(missing) > 0:
raise TweepError('Missing required attribute(s) %s' % str(missing).strip('[]'))
raise TweepError('Missing required attribute(s) %s' % \
str(missing).strip('[]'))
def validate(self):
return
class Status(Model):
@staticmethod
......@@ -36,12 +41,14 @@ class Status(Model):
])
if hasattr(status, 'user'):
User._validate(status.user)
def validate(self):
Status._validate(self)
def destroy(self):
return self._api.destroy_status(id=self.id)
class User(Model):
@staticmethod
......@@ -49,45 +56,55 @@ class User(Model):
Model._validate(user, [
'id', 'name', 'screen_name', 'location', 'description', 'profile_image_url',
'url', 'protected', 'followers_count', 'profile_background_color',
'profile_text_color', 'profile_sidebar_fill_color', 'profile_sidebar_border_color',
'friends_count', 'created_at', 'favourites_count', 'utc_offset', 'time_zone',
'profile_background_image_url', 'statuses_count', 'notifications', 'following',
'verified'
'profile_text_color', 'profile_sidebar_fill_color',
'profile_sidebar_border_color', 'friends_count', 'created_at',
'favourites_count', 'utc_offset', 'time_zone',
'profile_background_image_url', 'statuses_count',
'notifications', 'following', 'verified'
])
if hasattr(user, 'status'):
Status._validate(user.status)
def validate(self):
User._validate(self)
def timeline(self, **kargs):
return self._api.user_timeline(**kargs)
def mentions(self, **kargs):
return self._api.mentions(**kargs)
def friends(self, **kargs):
return self._api.friends(id=self.id, **kargs)
def followers(self, **kargs):
return self._api.followers(id=self.id, **kargs)
def follow(self):
self._api.create_friendship(user_id=self.id)
self.following = True
def unfollow(self):
self._api.destroy_friendship(user_id=self.id)
self.following = False
class DirectMessage(Model):
def destroy(self):
return self._api.destroy_direct_message(id=self.id)
class Friendship(Model):
pass
class SavedSearch(Model):
pass
class SearchResult(Model):
pass
......
......@@ -11,43 +11,52 @@ try:
except ImportError:
import simplejson as json
def parse_json(data, api):
return json.loads(data)
def parse_return_true(data, api):
return True
def parse_none(data, api):
return None
def parse_error(data):
return json.loads(data)['error']
def _parse_datetime(str):
return datetime.strptime(str, '%a %b %d %H:%M:%S +0000 %Y')
def _parse_search_datetime(str):
return datetime.strptime(str, '%a, %d %b %Y %H:%M:%S +0000')
def _parse_html_value(html):
return html[html.find('>')+1:html.rfind('<')]
def _parse_a_href(atag):
return atag[atag.find('"')+1:atag.find('>')-1]
def _parse_user(obj, api):
user = models['user']()
user._api = api
for k,v in obj.items():
for k, v in obj.items():
if k == 'created_at':
setattr(user, k, _parse_datetime(v))
elif k == 'status':
......@@ -62,10 +71,12 @@ def _parse_user(obj, api):
setattr(user, k, v)
return user
def parse_user(data, api):
return _parse_user(json.loads(data), api)
def parse_users(data, api):
users = []
......@@ -73,11 +84,12 @@ def parse_users(data, api):
users.append(_parse_user(obj, api))
return users
def _parse_status(obj, api):
status = models['status']()
status._api = api
for k,v in obj.items():
for k, v in obj.items():
if k == 'user':
setattr(status, 'author', _parse_user(v, api))
elif k == 'created_at':
......@@ -89,10 +101,12 @@ def _parse_status(obj, api):
setattr(status, k, v)
return status
def parse_status(data, api):
return _parse_status(json.loads(data), api)
def parse_statuses(data, api):
statuses = []
......@@ -100,11 +114,12 @@ def parse_statuses(data, api):
statuses.append(_parse_status(obj, api))
return statuses
def _parse_dm(obj, api):
dm = models['direct_message']()
dm._api = api
for k,v in obj.items():
for k, v in obj.items():
if k == 'sender' or k == 'recipient':
setattr(dm, k, _parse_user(v, api))
elif k == 'created_at':
......@@ -113,10 +128,12 @@ def _parse_dm(obj, api):
setattr(dm, k, v)
return dm
def parse_dm(data, api):
return _parse_dm(json.loads(data), api)
def parse_directmessages(data, api):
directmessages = []
......@@ -124,37 +141,41 @@ def parse_directmessages(data, api):
directmessages.append(_parse_dm(obj, api))
return directmessages
def parse_friendship(data, api):
relationship = json.loads(data)['relationship']
# parse source
source = models['friendship']()
for k,v in relationship['source'].items():
for k, v in relationship['source'].items():
setattr(source, k, v)
# parse target
target = models['friendship']()
for k,v in relationship['target'].items():
for k, v in relationship['target'].items():
setattr(target, k, v)
return source, target