Fix the prediction sample to comply with the updated training insert request and take the training-data location as a command-line flag.
Reviewed in http://codereview.appspot.com/4535113/
Index: samples/prediction/README
===================================================================
new file mode 100644
diff --git a/samples/prediction/prediction.py b/samples/prediction/prediction.py
index 81f7a76..9daf31e 100644
--- a/samples/prediction/prediction.py
+++ b/samples/prediction/prediction.py
@@ -17,11 +17,12 @@
"""Simple command-line sample for the Google Prediction API
-Command-line application that trains on some data. This sample does
-the same thing as the Hello Prediction! example.
+Command-line application that trains a model on your input data. This sample
+does the same thing as the Hello Prediction! example. To use the sample data,
+first run the setup.sh script, which loads it into Google Storage.
Usage:
- $ python prediction.py
+ $ python prediction.py --object_name="bucket/object"
You can also get help on all the command-line flags the program understands
by running:
@@ -71,6 +72,11 @@
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
'Set the level of logging detail.')
+gflags.DEFINE_string('object_name',
+ None,
+ 'Full Google Storage path of csv data (ex bucket/object)')
+
+gflags.MarkFlagAsRequired('object_name')
def main(argv):
# Let the gflags module process the command-line arguments
@@ -100,12 +106,10 @@
try:
- # Name of Google Storage bucket/object that contains the training data
- OBJECT_NAME = "apiclient-prediction-sample/prediction_models/languages"
-
# Start training on a data set
train = service.training()
- start = train.insert(data=OBJECT_NAME, body={}).execute()
+ body = {'id' : FLAGS.object_name}
+ start = train.insert(body=body).execute()
print 'Started training'
pprint.pprint(start)
@@ -113,7 +117,7 @@
import time
# Wait for the training to complete
while True:
- status = train.get(data=OBJECT_NAME).execute()
+ status = train.get(data=FLAGS.object_name).execute()
pprint.pprint(status)
if 'RUNNING' != status['trainingStatus']:
break
@@ -123,7 +127,7 @@
# Now make a prediction using that training
body = {'input': {'csvInstance': ["mucho bueno"]}}
- prediction = service.predict(body=body, data=OBJECT_NAME).execute()
+ prediction = service.predict(body=body, data=FLAGS.object_name).execute()
print 'The prediction is:'
pprint.pprint(prediction)